Python3 Migrate

This commit is contained in:
MariuszC
2020-01-18 20:01:00 +01:00
parent ea05af2d15
commit 6cd7e0fe44
691 changed files with 201846 additions and 598 deletions

View File

@@ -0,0 +1,26 @@
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""The Tornado web server and tools."""
# version is a human-readable version number.
# version_info is a four-tuple for programmatic comparison. The first
# three numbers are the components of the version number. The fourth
# is zero for an official release, positive for a development branch,
# or negative for a release candidate or beta (after the base version
# number has been incremented)
version = "6.0.3"
version_info = (6, 0, 3, 0)

View File

@@ -0,0 +1,82 @@
# -*- coding: utf-8 -*-
#
# Copyright 2012 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Data used by the tornado.locale module."""
LOCALE_NAMES = {
"af_ZA": {"name_en": u"Afrikaans", "name": u"Afrikaans"},
"am_ET": {"name_en": u"Amharic", "name": u"አማርኛ"},
"ar_AR": {"name_en": u"Arabic", "name": u"العربية"},
"bg_BG": {"name_en": u"Bulgarian", "name": u"Български"},
"bn_IN": {"name_en": u"Bengali", "name": u"বাংলা"},
"bs_BA": {"name_en": u"Bosnian", "name": u"Bosanski"},
"ca_ES": {"name_en": u"Catalan", "name": u"Català"},
"cs_CZ": {"name_en": u"Czech", "name": u"Čeština"},
"cy_GB": {"name_en": u"Welsh", "name": u"Cymraeg"},
"da_DK": {"name_en": u"Danish", "name": u"Dansk"},
"de_DE": {"name_en": u"German", "name": u"Deutsch"},
"el_GR": {"name_en": u"Greek", "name": u"Ελληνικά"},
"en_GB": {"name_en": u"English (UK)", "name": u"English (UK)"},
"en_US": {"name_en": u"English (US)", "name": u"English (US)"},
"es_ES": {"name_en": u"Spanish (Spain)", "name": u"Español (España)"},
"es_LA": {"name_en": u"Spanish", "name": u"Español"},
"et_EE": {"name_en": u"Estonian", "name": u"Eesti"},
"eu_ES": {"name_en": u"Basque", "name": u"Euskara"},
"fa_IR": {"name_en": u"Persian", "name": u"فارسی"},
"fi_FI": {"name_en": u"Finnish", "name": u"Suomi"},
"fr_CA": {"name_en": u"French (Canada)", "name": u"Français (Canada)"},
"fr_FR": {"name_en": u"French", "name": u"Français"},
"ga_IE": {"name_en": u"Irish", "name": u"Gaeilge"},
"gl_ES": {"name_en": u"Galician", "name": u"Galego"},
"he_IL": {"name_en": u"Hebrew", "name": u"עברית"},
"hi_IN": {"name_en": u"Hindi", "name": u"हिन्दी"},
"hr_HR": {"name_en": u"Croatian", "name": u"Hrvatski"},
"hu_HU": {"name_en": u"Hungarian", "name": u"Magyar"},
"id_ID": {"name_en": u"Indonesian", "name": u"Bahasa Indonesia"},
"is_IS": {"name_en": u"Icelandic", "name": u"Íslenska"},
"it_IT": {"name_en": u"Italian", "name": u"Italiano"},
"ja_JP": {"name_en": u"Japanese", "name": u"日本語"},
"ko_KR": {"name_en": u"Korean", "name": u"한국어"},
"lt_LT": {"name_en": u"Lithuanian", "name": u"Lietuvių"},
"lv_LV": {"name_en": u"Latvian", "name": u"Latviešu"},
"mk_MK": {"name_en": u"Macedonian", "name": u"Македонски"},
"ml_IN": {"name_en": u"Malayalam", "name": u"മലയാളം"},
"ms_MY": {"name_en": u"Malay", "name": u"Bahasa Melayu"},
"nb_NO": {"name_en": u"Norwegian (bokmal)", "name": u"Norsk (bokmål)"},
"nl_NL": {"name_en": u"Dutch", "name": u"Nederlands"},
"nn_NO": {"name_en": u"Norwegian (nynorsk)", "name": u"Norsk (nynorsk)"},
"pa_IN": {"name_en": u"Punjabi", "name": u"ਪੰਜਾਬੀ"},
"pl_PL": {"name_en": u"Polish", "name": u"Polski"},
"pt_BR": {"name_en": u"Portuguese (Brazil)", "name": u"Português (Brasil)"},
"pt_PT": {"name_en": u"Portuguese (Portugal)", "name": u"Português (Portugal)"},
"ro_RO": {"name_en": u"Romanian", "name": u"Română"},
"ru_RU": {"name_en": u"Russian", "name": u"Русский"},
"sk_SK": {"name_en": u"Slovak", "name": u"Slovenčina"},
"sl_SI": {"name_en": u"Slovenian", "name": u"Slovenščina"},
"sq_AL": {"name_en": u"Albanian", "name": u"Shqip"},
"sr_RS": {"name_en": u"Serbian", "name": u"Српски"},
"sv_SE": {"name_en": u"Swedish", "name": u"Svenska"},
"sw_KE": {"name_en": u"Swahili", "name": u"Kiswahili"},
"ta_IN": {"name_en": u"Tamil", "name": u"தமிழ்"},
"te_IN": {"name_en": u"Telugu", "name": u"తెలుగు"},
"th_TH": {"name_en": u"Thai", "name": u"ภาษาไทย"},
"tl_PH": {"name_en": u"Filipino", "name": u"Filipino"},
"tr_TR": {"name_en": u"Turkish", "name": u"Türkçe"},
"uk_UA": {"name_en": u"Ukraini ", "name": u"Українська"},
"vi_VN": {"name_en": u"Vietnamese", "name": u"Tiếng Việt"},
"zh_CN": {"name_en": u"Chinese (Simplified)", "name": u"中文(简体)"},
"zh_TW": {"name_en": u"Chinese (Traditional)", "name": u"中文(繁體)"},
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,364 @@
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Automatically restart the server when a source file is modified.
Most applications should not access this module directly. Instead,
pass the keyword argument ``autoreload=True`` to the
`tornado.web.Application` constructor (or ``debug=True``, which
enables this setting and several others). This will enable autoreload
mode as well as checking for changes to templates and static
resources. Note that restarting is a destructive operation and any
requests in progress will be aborted when the process restarts. (If
you want to disable autoreload while using other debug-mode features,
pass both ``debug=True`` and ``autoreload=False``).
This module can also be used as a command-line wrapper around scripts
such as unit test runners. See the `main` method for details.
The command-line wrapper and Application debug modes can be used together.
This combination is encouraged as the wrapper catches syntax errors and
other import-time failures, while debug mode catches changes once
the server has started.
This module will not work correctly when `.HTTPServer`'s multi-process
mode is used.
Reloading loses any Python interpreter command-line arguments (e.g. ``-u``)
because it re-executes Python using ``sys.executable`` and ``sys.argv``.
Additionally, modifying these variables will cause reloading to behave
incorrectly.
"""
import os
import sys
# sys.path handling
# -----------------
#
# If a module is run with "python -m", the current directory (i.e. "")
# is automatically prepended to sys.path, but not if it is run as
# "path/to/file.py". The processing for "-m" rewrites the former to
# the latter, so subsequent executions won't have the same path as the
# original.
#
# Conversely, when run as path/to/file.py, the directory containing
# file.py gets added to the path, which can cause confusion as imports
# may become relative in spite of the future import.
#
# We address the former problem by reconstructing the original command
# line (Python >= 3.4) or by setting the $PYTHONPATH environment
# variable (Python < 3.4) before re-execution so the new process will
# see the correct path. We attempt to address the latter problem when
# tornado.autoreload is run as __main__.
if __name__ == "__main__":
# This sys.path manipulation must come before our imports (as much
# as possible - if we introduced a tornado.sys or tornado.os
# module we'd be in trouble), or else our imports would become
# relative again despite the future import.
#
# There is a separate __main__ block at the end of the file to call main().
if sys.path[0] == os.path.dirname(__file__):
del sys.path[0]
import functools
import logging
import os
import pkgutil # type: ignore
import sys
import traceback
import types
import subprocess
import weakref
from tornado import ioloop
from tornado.log import gen_log
from tornado import process
from tornado.util import exec_in
try:
import signal
except ImportError:
signal = None # type: ignore
import typing
from typing import Callable, Dict
if typing.TYPE_CHECKING:
from typing import List, Optional, Union # noqa: F401
# os.execv is broken on Windows and can't properly parse command line
# arguments and executable name if they contain whitespaces. subprocess
# fixes that behavior.
_has_execv = sys.platform != "win32"
_watched_files = set()
_reload_hooks = []
_reload_attempted = False
_io_loops = weakref.WeakKeyDictionary() # type: ignore
_autoreload_is_main = False
_original_argv = None # type: Optional[List[str]]
_original_spec = None
def start(check_time: int = 500) -> None:
"""Begins watching source files for changes.
.. versionchanged:: 5.0
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
"""
io_loop = ioloop.IOLoop.current()
if io_loop in _io_loops:
return
_io_loops[io_loop] = True
if len(_io_loops) > 1:
gen_log.warning("tornado.autoreload started more than once in the same process")
modify_times = {} # type: Dict[str, float]
callback = functools.partial(_reload_on_update, modify_times)
scheduler = ioloop.PeriodicCallback(callback, check_time)
scheduler.start()
def wait() -> None:
"""Wait for a watched file to change, then restart the process.
Intended to be used at the end of scripts like unit test runners,
to run the tests again after any source file changes (but see also
the command-line interface in `main`)
"""
io_loop = ioloop.IOLoop()
io_loop.add_callback(start)
io_loop.start()
def watch(filename: str) -> None:
"""Add a file to the watch list.
All imported modules are watched by default.
"""
_watched_files.add(filename)
def add_reload_hook(fn: Callable[[], None]) -> None:
"""Add a function to be called before reloading the process.
Note that for open file and socket handles it is generally
preferable to set the ``FD_CLOEXEC`` flag (using `fcntl` or
``tornado.platform.auto.set_close_exec``) instead
of using a reload hook to close them.
"""
_reload_hooks.append(fn)
def _reload_on_update(modify_times: Dict[str, float]) -> None:
if _reload_attempted:
# We already tried to reload and it didn't work, so don't try again.
return
if process.task_id() is not None:
# We're in a child process created by fork_processes. If child
# processes restarted themselves, they'd all restart and then
# all call fork_processes again.
return
for module in list(sys.modules.values()):
# Some modules play games with sys.modules (e.g. email/__init__.py
# in the standard library), and occasionally this can cause strange
# failures in getattr. Just ignore anything that's not an ordinary
# module.
if not isinstance(module, types.ModuleType):
continue
path = getattr(module, "__file__", None)
if not path:
continue
if path.endswith(".pyc") or path.endswith(".pyo"):
path = path[:-1]
_check_file(modify_times, path)
for path in _watched_files:
_check_file(modify_times, path)
def _check_file(modify_times: Dict[str, float], path: str) -> None:
try:
modified = os.stat(path).st_mtime
except Exception:
return
if path not in modify_times:
modify_times[path] = modified
return
if modify_times[path] != modified:
gen_log.info("%s modified; restarting server", path)
_reload()
def _reload() -> None:
global _reload_attempted
_reload_attempted = True
for fn in _reload_hooks:
fn()
if hasattr(signal, "setitimer"):
# Clear the alarm signal set by
# ioloop.set_blocking_log_threshold so it doesn't fire
# after the exec.
signal.setitimer(signal.ITIMER_REAL, 0, 0)
# sys.path fixes: see comments at top of file. If __main__.__spec__
# exists, we were invoked with -m and the effective path is about to
# change on re-exec. Reconstruct the original command line to
# ensure that the new process sees the same path we did. If
# __spec__ is not available (Python < 3.4), check instead if
# sys.path[0] is an empty string and add the current directory to
# $PYTHONPATH.
if _autoreload_is_main:
assert _original_argv is not None
spec = _original_spec
argv = _original_argv
else:
spec = getattr(sys.modules["__main__"], "__spec__", None)
argv = sys.argv
if spec:
argv = ["-m", spec.name] + argv[1:]
else:
path_prefix = "." + os.pathsep
if sys.path[0] == "" and not os.environ.get("PYTHONPATH", "").startswith(
path_prefix
):
os.environ["PYTHONPATH"] = path_prefix + os.environ.get("PYTHONPATH", "")
if not _has_execv:
subprocess.Popen([sys.executable] + argv)
os._exit(0)
else:
try:
os.execv(sys.executable, [sys.executable] + argv)
except OSError:
# Mac OS X versions prior to 10.6 do not support execv in
# a process that contains multiple threads. Instead of
# re-executing in the current process, start a new one
# and cause the current process to exit. This isn't
# ideal since the new process is detached from the parent
# terminal and thus cannot easily be killed with ctrl-C,
# but it's better than not being able to autoreload at
# all.
# Unfortunately the errno returned in this case does not
# appear to be consistent, so we can't easily check for
# this error specifically.
os.spawnv( # type: ignore
os.P_NOWAIT, sys.executable, [sys.executable] + argv
)
# At this point the IOLoop has been closed and finally
# blocks will experience errors if we allow the stack to
# unwind, so just exit uncleanly.
os._exit(0)
_USAGE = """\
Usage:
python -m tornado.autoreload -m module.to.run [args...]
python -m tornado.autoreload path/to/script.py [args...]
"""
def main() -> None:
"""Command-line wrapper to re-run a script whenever its source changes.
Scripts may be specified by filename or module name::
python -m tornado.autoreload -m tornado.test.runtests
python -m tornado.autoreload tornado/test/runtests.py
Running a script with this wrapper is similar to calling
`tornado.autoreload.wait` at the end of the script, but this wrapper
can catch import-time problems like syntax errors that would otherwise
prevent the script from reaching its call to `wait`.
"""
# Remember that we were launched with autoreload as main.
# The main module can be tricky; set the variables both in our globals
# (which may be __main__) and the real importable version.
import tornado.autoreload
global _autoreload_is_main
global _original_argv, _original_spec
tornado.autoreload._autoreload_is_main = _autoreload_is_main = True
original_argv = sys.argv
tornado.autoreload._original_argv = _original_argv = original_argv
original_spec = getattr(sys.modules["__main__"], "__spec__", None)
tornado.autoreload._original_spec = _original_spec = original_spec
sys.argv = sys.argv[:]
if len(sys.argv) >= 3 and sys.argv[1] == "-m":
mode = "module"
module = sys.argv[2]
del sys.argv[1:3]
elif len(sys.argv) >= 2:
mode = "script"
script = sys.argv[1]
sys.argv = sys.argv[1:]
else:
print(_USAGE, file=sys.stderr)
sys.exit(1)
try:
if mode == "module":
import runpy
runpy.run_module(module, run_name="__main__", alter_sys=True)
elif mode == "script":
with open(script) as f:
# Execute the script in our namespace instead of creating
# a new one so that something that tries to import __main__
# (e.g. the unittest module) will see names defined in the
# script instead of just those defined in this module.
global __file__
__file__ = script
# If __package__ is defined, imports may be incorrectly
# interpreted as relative to this module.
global __package__
del __package__
exec_in(f.read(), globals(), globals())
except SystemExit as e:
logging.basicConfig()
gen_log.info("Script exited with status %s", e.code)
except Exception as e:
logging.basicConfig()
gen_log.warning("Script exited with uncaught exception", exc_info=True)
# If an exception occurred at import time, the file with the error
# never made it into sys.modules and so we won't know to watch it.
# Just to make sure we've covered everything, walk the stack trace
# from the exception and watch every file.
for (filename, lineno, name, line) in traceback.extract_tb(sys.exc_info()[2]):
watch(filename)
if isinstance(e, SyntaxError):
# SyntaxErrors are special: their innermost stack frame is fake
# so extract_tb won't see it and we have to get the filename
# from the exception object.
watch(e.filename)
else:
logging.basicConfig()
gen_log.info("Script exited normally")
# restore sys.argv so subsequent executions will include autoreload
sys.argv = original_argv
if mode == "module":
# runpy did a fake import of the module as __main__, but now it's
# no longer in sys.modules. Figure out where it is and watch it.
loader = pkgutil.get_loader(module)
if loader is not None:
watch(loader.get_filename()) # type: ignore
wait()
if __name__ == "__main__":
# See also the other __main__ block at the top of the file, which modifies
# sys.path before our imports
main()

View File

@@ -0,0 +1,264 @@
#
# Copyright 2012 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Utilities for working with ``Future`` objects.
Tornado previously provided its own ``Future`` class, but now uses
`asyncio.Future`. This module contains utility functions for working
with `asyncio.Future` in a way that is backwards-compatible with
Tornado's old ``Future`` implementation.
While this module is an important part of Tornado's internal
implementation, applications rarely need to interact with it
directly.
"""
import asyncio
from concurrent import futures
import functools
import sys
import types
from tornado.log import app_log
import typing
from typing import Any, Callable, Optional, Tuple, Union
_T = typing.TypeVar("_T")
class ReturnValueIgnoredError(Exception):
# No longer used; was previously used by @return_future
pass
Future = asyncio.Future
FUTURES = (futures.Future, Future)
def is_future(x: Any) -> bool:
return isinstance(x, FUTURES)
class DummyExecutor(futures.Executor):
def submit(
self, fn: Callable[..., _T], *args: Any, **kwargs: Any
) -> "futures.Future[_T]":
future = futures.Future() # type: futures.Future[_T]
try:
future_set_result_unless_cancelled(future, fn(*args, **kwargs))
except Exception:
future_set_exc_info(future, sys.exc_info())
return future
def shutdown(self, wait: bool = True) -> None:
pass
dummy_executor = DummyExecutor()
def run_on_executor(*args: Any, **kwargs: Any) -> Callable:
"""Decorator to run a synchronous method asynchronously on an executor.
The decorated method may be called with a ``callback`` keyword
argument and returns a future.
The executor to be used is determined by the ``executor``
attributes of ``self``. To use a different attribute name, pass a
keyword argument to the decorator::
@run_on_executor(executor='_thread_pool')
def foo(self):
pass
This decorator should not be confused with the similarly-named
`.IOLoop.run_in_executor`. In general, using ``run_in_executor``
when *calling* a blocking method is recommended instead of using
this decorator when *defining* a method. If compatibility with older
versions of Tornado is required, consider defining an executor
and using ``executor.submit()`` at the call site.
.. versionchanged:: 4.2
Added keyword arguments to use alternative attributes.
.. versionchanged:: 5.0
Always uses the current IOLoop instead of ``self.io_loop``.
.. versionchanged:: 5.1
Returns a `.Future` compatible with ``await`` instead of a
`concurrent.futures.Future`.
.. deprecated:: 5.1
The ``callback`` argument is deprecated and will be removed in
6.0. The decorator itself is discouraged in new code but will
not be removed in 6.0.
.. versionchanged:: 6.0
The ``callback`` argument was removed.
"""
# Fully type-checking decorators is tricky, and this one is
# discouraged anyway so it doesn't have all the generic magic.
def run_on_executor_decorator(fn: Callable) -> Callable[..., Future]:
executor = kwargs.get("executor", "executor")
@functools.wraps(fn)
def wrapper(self: Any, *args: Any, **kwargs: Any) -> Future:
async_future = Future() # type: Future
conc_future = getattr(self, executor).submit(fn, self, *args, **kwargs)
chain_future(conc_future, async_future)
return async_future
return wrapper
if args and kwargs:
raise ValueError("cannot combine positional and keyword args")
if len(args) == 1:
return run_on_executor_decorator(args[0])
elif len(args) != 0:
raise ValueError("expected 1 argument, got %d", len(args))
return run_on_executor_decorator
_NO_RESULT = object()
def chain_future(a: "Future[_T]", b: "Future[_T]") -> None:
"""Chain two futures together so that when one completes, so does the other.
The result (success or failure) of ``a`` will be copied to ``b``, unless
``b`` has already been completed or cancelled by the time ``a`` finishes.
.. versionchanged:: 5.0
Now accepts both Tornado/asyncio `Future` objects and
`concurrent.futures.Future`.
"""
def copy(future: "Future[_T]") -> None:
assert future is a
if b.done():
return
if hasattr(a, "exc_info") and a.exc_info() is not None: # type: ignore
future_set_exc_info(b, a.exc_info()) # type: ignore
elif a.exception() is not None:
b.set_exception(a.exception())
else:
b.set_result(a.result())
if isinstance(a, Future):
future_add_done_callback(a, copy)
else:
# concurrent.futures.Future
from tornado.ioloop import IOLoop
IOLoop.current().add_future(a, copy)
def future_set_result_unless_cancelled(
future: "Union[futures.Future[_T], Future[_T]]", value: _T
) -> None:
"""Set the given ``value`` as the `Future`'s result, if not cancelled.
Avoids ``asyncio.InvalidStateError`` when calling ``set_result()`` on
a cancelled `asyncio.Future`.
.. versionadded:: 5.0
"""
if not future.cancelled():
future.set_result(value)
def future_set_exception_unless_cancelled(
future: "Union[futures.Future[_T], Future[_T]]", exc: BaseException
) -> None:
"""Set the given ``exc`` as the `Future`'s exception.
If the Future is already canceled, logs the exception instead. If
this logging is not desired, the caller should explicitly check
the state of the Future and call ``Future.set_exception`` instead of
this wrapper.
Avoids ``asyncio.InvalidStateError`` when calling ``set_exception()`` on
a cancelled `asyncio.Future`.
.. versionadded:: 6.0
"""
if not future.cancelled():
future.set_exception(exc)
else:
app_log.error("Exception after Future was cancelled", exc_info=exc)
def future_set_exc_info(
future: "Union[futures.Future[_T], Future[_T]]",
exc_info: Tuple[
Optional[type], Optional[BaseException], Optional[types.TracebackType]
],
) -> None:
"""Set the given ``exc_info`` as the `Future`'s exception.
Understands both `asyncio.Future` and the extensions in older
versions of Tornado to enable better tracebacks on Python 2.
.. versionadded:: 5.0
.. versionchanged:: 6.0
If the future is already cancelled, this function is a no-op.
(previously ``asyncio.InvalidStateError`` would be raised)
"""
if exc_info[1] is None:
raise Exception("future_set_exc_info called with no exception")
future_set_exception_unless_cancelled(future, exc_info[1])
@typing.overload
def future_add_done_callback(
future: "futures.Future[_T]", callback: Callable[["futures.Future[_T]"], None]
) -> None:
pass
@typing.overload # noqa: F811
def future_add_done_callback(
future: "Future[_T]", callback: Callable[["Future[_T]"], None]
) -> None:
pass
def future_add_done_callback( # noqa: F811
future: "Union[futures.Future[_T], Future[_T]]", callback: Callable[..., None]
) -> None:
"""Arrange to call ``callback`` when ``future`` is complete.
``callback`` is invoked with one argument, the ``future``.
If ``future`` is already done, ``callback`` is invoked immediately.
This may differ from the behavior of ``Future.add_done_callback``,
which makes no such guarantee.
.. versionadded:: 5.0
"""
if future.done():
callback(future)
else:
future.add_done_callback(callback)

View File

@@ -0,0 +1,572 @@
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Non-blocking HTTP client implementation using pycurl."""
import collections
import functools
import logging
import pycurl # type: ignore
import threading
import time
from io import BytesIO
from tornado import httputil
from tornado import ioloop
from tornado.escape import utf8, native_str
from tornado.httpclient import (
HTTPRequest,
HTTPResponse,
HTTPError,
AsyncHTTPClient,
main,
)
from tornado.log import app_log
from typing import Dict, Any, Callable, Union
import typing
if typing.TYPE_CHECKING:
from typing import Deque, Tuple, Optional # noqa: F401
curl_log = logging.getLogger("tornado.curl_httpclient")
class CurlAsyncHTTPClient(AsyncHTTPClient):
def initialize( # type: ignore
self, max_clients: int = 10, defaults: Dict[str, Any] = None
) -> None:
super(CurlAsyncHTTPClient, self).initialize(defaults=defaults)
self._multi = pycurl.CurlMulti()
self._multi.setopt(pycurl.M_TIMERFUNCTION, self._set_timeout)
self._multi.setopt(pycurl.M_SOCKETFUNCTION, self._handle_socket)
self._curls = [self._curl_create() for i in range(max_clients)]
self._free_list = self._curls[:]
self._requests = (
collections.deque()
) # type: Deque[Tuple[HTTPRequest, Callable[[HTTPResponse], None], float]]
self._fds = {} # type: Dict[int, int]
self._timeout = None # type: Optional[object]
# libcurl has bugs that sometimes cause it to not report all
# relevant file descriptors and timeouts to TIMERFUNCTION/
# SOCKETFUNCTION. Mitigate the effects of such bugs by
# forcing a periodic scan of all active requests.
self._force_timeout_callback = ioloop.PeriodicCallback(
self._handle_force_timeout, 1000
)
self._force_timeout_callback.start()
# Work around a bug in libcurl 7.29.0: Some fields in the curl
# multi object are initialized lazily, and its destructor will
# segfault if it is destroyed without having been used. Add
# and remove a dummy handle to make sure everything is
# initialized.
dummy_curl_handle = pycurl.Curl()
self._multi.add_handle(dummy_curl_handle)
self._multi.remove_handle(dummy_curl_handle)
def close(self) -> None:
self._force_timeout_callback.stop()
if self._timeout is not None:
self.io_loop.remove_timeout(self._timeout)
for curl in self._curls:
curl.close()
self._multi.close()
super(CurlAsyncHTTPClient, self).close()
# Set below properties to None to reduce the reference count of current
# instance, because those properties hold some methods of current
# instance that will case circular reference.
self._force_timeout_callback = None # type: ignore
self._multi = None
def fetch_impl(
self, request: HTTPRequest, callback: Callable[[HTTPResponse], None]
) -> None:
self._requests.append((request, callback, self.io_loop.time()))
self._process_queue()
self._set_timeout(0)
def _handle_socket(self, event: int, fd: int, multi: Any, data: bytes) -> None:
"""Called by libcurl when it wants to change the file descriptors
it cares about.
"""
event_map = {
pycurl.POLL_NONE: ioloop.IOLoop.NONE,
pycurl.POLL_IN: ioloop.IOLoop.READ,
pycurl.POLL_OUT: ioloop.IOLoop.WRITE,
pycurl.POLL_INOUT: ioloop.IOLoop.READ | ioloop.IOLoop.WRITE,
}
if event == pycurl.POLL_REMOVE:
if fd in self._fds:
self.io_loop.remove_handler(fd)
del self._fds[fd]
else:
ioloop_event = event_map[event]
# libcurl sometimes closes a socket and then opens a new
# one using the same FD without giving us a POLL_NONE in
# between. This is a problem with the epoll IOLoop,
# because the kernel can tell when a socket is closed and
# removes it from the epoll automatically, causing future
# update_handler calls to fail. Since we can't tell when
# this has happened, always use remove and re-add
# instead of update.
if fd in self._fds:
self.io_loop.remove_handler(fd)
self.io_loop.add_handler(fd, self._handle_events, ioloop_event)
self._fds[fd] = ioloop_event
def _set_timeout(self, msecs: int) -> None:
"""Called by libcurl to schedule a timeout."""
if self._timeout is not None:
self.io_loop.remove_timeout(self._timeout)
self._timeout = self.io_loop.add_timeout(
self.io_loop.time() + msecs / 1000.0, self._handle_timeout
)
def _handle_events(self, fd: int, events: int) -> None:
"""Called by IOLoop when there is activity on one of our
file descriptors.
"""
action = 0
if events & ioloop.IOLoop.READ:
action |= pycurl.CSELECT_IN
if events & ioloop.IOLoop.WRITE:
action |= pycurl.CSELECT_OUT
while True:
try:
ret, num_handles = self._multi.socket_action(fd, action)
except pycurl.error as e:
ret = e.args[0]
if ret != pycurl.E_CALL_MULTI_PERFORM:
break
self._finish_pending_requests()
def _handle_timeout(self) -> None:
"""Called by IOLoop when the requested timeout has passed."""
self._timeout = None
while True:
try:
ret, num_handles = self._multi.socket_action(pycurl.SOCKET_TIMEOUT, 0)
except pycurl.error as e:
ret = e.args[0]
if ret != pycurl.E_CALL_MULTI_PERFORM:
break
self._finish_pending_requests()
# In theory, we shouldn't have to do this because curl will
# call _set_timeout whenever the timeout changes. However,
# sometimes after _handle_timeout we will need to reschedule
# immediately even though nothing has changed from curl's
# perspective. This is because when socket_action is
# called with SOCKET_TIMEOUT, libcurl decides internally which
# timeouts need to be processed by using a monotonic clock
# (where available) while tornado uses python's time.time()
# to decide when timeouts have occurred. When those clocks
# disagree on elapsed time (as they will whenever there is an
# NTP adjustment), tornado might call _handle_timeout before
# libcurl is ready. After each timeout, resync the scheduled
# timeout with libcurl's current state.
new_timeout = self._multi.timeout()
if new_timeout >= 0:
self._set_timeout(new_timeout)
def _handle_force_timeout(self) -> None:
"""Called by IOLoop periodically to ask libcurl to process any
events it may have forgotten about.
"""
while True:
try:
ret, num_handles = self._multi.socket_all()
except pycurl.error as e:
ret = e.args[0]
if ret != pycurl.E_CALL_MULTI_PERFORM:
break
self._finish_pending_requests()
def _finish_pending_requests(self) -> None:
"""Process any requests that were completed by the last
call to multi.socket_action.
"""
while True:
num_q, ok_list, err_list = self._multi.info_read()
for curl in ok_list:
self._finish(curl)
for curl, errnum, errmsg in err_list:
self._finish(curl, errnum, errmsg)
if num_q == 0:
break
self._process_queue()
def _process_queue(self) -> None:
while True:
started = 0
while self._free_list and self._requests:
started += 1
curl = self._free_list.pop()
(request, callback, queue_start_time) = self._requests.popleft()
curl.info = {
"headers": httputil.HTTPHeaders(),
"buffer": BytesIO(),
"request": request,
"callback": callback,
"queue_start_time": queue_start_time,
"curl_start_time": time.time(),
"curl_start_ioloop_time": self.io_loop.current().time(),
}
try:
self._curl_setup_request(
curl, request, curl.info["buffer"], curl.info["headers"]
)
except Exception as e:
# If there was an error in setup, pass it on
# to the callback. Note that allowing the
# error to escape here will appear to work
# most of the time since we are still in the
# caller's original stack frame, but when
# _process_queue() is called from
# _finish_pending_requests the exceptions have
# nowhere to go.
self._free_list.append(curl)
callback(HTTPResponse(request=request, code=599, error=e))
else:
self._multi.add_handle(curl)
if not started:
break
def _finish(
self, curl: pycurl.Curl, curl_error: int = None, curl_message: str = None
) -> None:
info = curl.info
curl.info = None
self._multi.remove_handle(curl)
self._free_list.append(curl)
buffer = info["buffer"]
if curl_error:
assert curl_message is not None
error = CurlError(curl_error, curl_message) # type: Optional[CurlError]
assert error is not None
code = error.code
effective_url = None
buffer.close()
buffer = None
else:
error = None
code = curl.getinfo(pycurl.HTTP_CODE)
effective_url = curl.getinfo(pycurl.EFFECTIVE_URL)
buffer.seek(0)
# the various curl timings are documented at
# http://curl.haxx.se/libcurl/c/curl_easy_getinfo.html
time_info = dict(
queue=info["curl_start_ioloop_time"] - info["queue_start_time"],
namelookup=curl.getinfo(pycurl.NAMELOOKUP_TIME),
connect=curl.getinfo(pycurl.CONNECT_TIME),
appconnect=curl.getinfo(pycurl.APPCONNECT_TIME),
pretransfer=curl.getinfo(pycurl.PRETRANSFER_TIME),
starttransfer=curl.getinfo(pycurl.STARTTRANSFER_TIME),
total=curl.getinfo(pycurl.TOTAL_TIME),
redirect=curl.getinfo(pycurl.REDIRECT_TIME),
)
try:
info["callback"](
HTTPResponse(
request=info["request"],
code=code,
headers=info["headers"],
buffer=buffer,
effective_url=effective_url,
error=error,
reason=info["headers"].get("X-Http-Reason", None),
request_time=self.io_loop.time() - info["curl_start_ioloop_time"],
start_time=info["curl_start_time"],
time_info=time_info,
)
)
except Exception:
self.handle_callback_exception(info["callback"])
def handle_callback_exception(self, callback: Any) -> None:
app_log.error("Exception in callback %r", callback, exc_info=True)
def _curl_create(self) -> pycurl.Curl:
curl = pycurl.Curl()
if curl_log.isEnabledFor(logging.DEBUG):
curl.setopt(pycurl.VERBOSE, 1)
curl.setopt(pycurl.DEBUGFUNCTION, self._curl_debug)
if hasattr(
pycurl, "PROTOCOLS"
): # PROTOCOLS first appeared in pycurl 7.19.5 (2014-07-12)
curl.setopt(pycurl.PROTOCOLS, pycurl.PROTO_HTTP | pycurl.PROTO_HTTPS)
curl.setopt(pycurl.REDIR_PROTOCOLS, pycurl.PROTO_HTTP | pycurl.PROTO_HTTPS)
return curl
def _curl_setup_request(
self,
curl: pycurl.Curl,
request: HTTPRequest,
buffer: BytesIO,
headers: httputil.HTTPHeaders,
) -> None:
curl.setopt(pycurl.URL, native_str(request.url))
# libcurl's magic "Expect: 100-continue" behavior causes delays
# with servers that don't support it (which include, among others,
# Google's OpenID endpoint). Additionally, this behavior has
# a bug in conjunction with the curl_multi_socket_action API
# (https://sourceforge.net/tracker/?func=detail&atid=100976&aid=3039744&group_id=976),
# which increases the delays. It's more trouble than it's worth,
# so just turn off the feature (yes, setting Expect: to an empty
# value is the official way to disable this)
if "Expect" not in request.headers:
request.headers["Expect"] = ""
# libcurl adds Pragma: no-cache by default; disable that too
if "Pragma" not in request.headers:
request.headers["Pragma"] = ""
curl.setopt(
pycurl.HTTPHEADER,
[
"%s: %s" % (native_str(k), native_str(v))
for k, v in request.headers.get_all()
],
)
curl.setopt(
pycurl.HEADERFUNCTION,
functools.partial(
self._curl_header_callback, headers, request.header_callback
),
)
if request.streaming_callback:
def write_function(b: Union[bytes, bytearray]) -> int:
assert request.streaming_callback is not None
self.io_loop.add_callback(request.streaming_callback, b)
return len(b)
else:
write_function = buffer.write
curl.setopt(pycurl.WRITEFUNCTION, write_function)
curl.setopt(pycurl.FOLLOWLOCATION, request.follow_redirects)
curl.setopt(pycurl.MAXREDIRS, request.max_redirects)
assert request.connect_timeout is not None
curl.setopt(pycurl.CONNECTTIMEOUT_MS, int(1000 * request.connect_timeout))
assert request.request_timeout is not None
curl.setopt(pycurl.TIMEOUT_MS, int(1000 * request.request_timeout))
if request.user_agent:
curl.setopt(pycurl.USERAGENT, native_str(request.user_agent))
else:
curl.setopt(pycurl.USERAGENT, "Mozilla/5.0 (compatible; pycurl)")
if request.network_interface:
curl.setopt(pycurl.INTERFACE, request.network_interface)
if request.decompress_response:
curl.setopt(pycurl.ENCODING, "gzip,deflate")
else:
curl.setopt(pycurl.ENCODING, "none")
if request.proxy_host and request.proxy_port:
curl.setopt(pycurl.PROXY, request.proxy_host)
curl.setopt(pycurl.PROXYPORT, request.proxy_port)
if request.proxy_username:
assert request.proxy_password is not None
credentials = httputil.encode_username_password(
request.proxy_username, request.proxy_password
)
curl.setopt(pycurl.PROXYUSERPWD, credentials)
if request.proxy_auth_mode is None or request.proxy_auth_mode == "basic":
curl.setopt(pycurl.PROXYAUTH, pycurl.HTTPAUTH_BASIC)
elif request.proxy_auth_mode == "digest":
curl.setopt(pycurl.PROXYAUTH, pycurl.HTTPAUTH_DIGEST)
else:
raise ValueError(
"Unsupported proxy_auth_mode %s" % request.proxy_auth_mode
)
else:
curl.setopt(pycurl.PROXY, "")
curl.unsetopt(pycurl.PROXYUSERPWD)
if request.validate_cert:
curl.setopt(pycurl.SSL_VERIFYPEER, 1)
curl.setopt(pycurl.SSL_VERIFYHOST, 2)
else:
curl.setopt(pycurl.SSL_VERIFYPEER, 0)
curl.setopt(pycurl.SSL_VERIFYHOST, 0)
if request.ca_certs is not None:
curl.setopt(pycurl.CAINFO, request.ca_certs)
else:
# There is no way to restore pycurl.CAINFO to its default value
# (Using unsetopt makes it reject all certificates).
# I don't see any way to read the default value from python so it
# can be restored later. We'll have to just leave CAINFO untouched
# if no ca_certs file was specified, and require that if any
# request uses a custom ca_certs file, they all must.
pass
if request.allow_ipv6 is False:
# Curl behaves reasonably when DNS resolution gives an ipv6 address
# that we can't reach, so allow ipv6 unless the user asks to disable.
curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V4)
else:
curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_WHATEVER)
# Set the request method through curl's irritating interface which makes
# up names for almost every single method
curl_options = {
"GET": pycurl.HTTPGET,
"POST": pycurl.POST,
"PUT": pycurl.UPLOAD,
"HEAD": pycurl.NOBODY,
}
custom_methods = set(["DELETE", "OPTIONS", "PATCH"])
for o in curl_options.values():
curl.setopt(o, False)
if request.method in curl_options:
curl.unsetopt(pycurl.CUSTOMREQUEST)
curl.setopt(curl_options[request.method], True)
elif request.allow_nonstandard_methods or request.method in custom_methods:
curl.setopt(pycurl.CUSTOMREQUEST, request.method)
else:
raise KeyError("unknown method " + request.method)
body_expected = request.method in ("POST", "PATCH", "PUT")
body_present = request.body is not None
if not request.allow_nonstandard_methods:
# Some HTTP methods nearly always have bodies while others
# almost never do. Fail in this case unless the user has
# opted out of sanity checks with allow_nonstandard_methods.
if (body_expected and not body_present) or (
body_present and not body_expected
):
raise ValueError(
"Body must %sbe None for method %s (unless "
"allow_nonstandard_methods is true)"
% ("not " if body_expected else "", request.method)
)
if body_expected or body_present:
if request.method == "GET":
# Even with `allow_nonstandard_methods` we disallow
# GET with a body (because libcurl doesn't allow it
# unless we use CUSTOMREQUEST). While the spec doesn't
# forbid clients from sending a body, it arguably
# disallows the server from doing anything with them.
raise ValueError("Body must be None for GET request")
request_buffer = BytesIO(utf8(request.body or ""))
def ioctl(cmd: int) -> None:
if cmd == curl.IOCMD_RESTARTREAD:
request_buffer.seek(0)
curl.setopt(pycurl.READFUNCTION, request_buffer.read)
curl.setopt(pycurl.IOCTLFUNCTION, ioctl)
if request.method == "POST":
curl.setopt(pycurl.POSTFIELDSIZE, len(request.body or ""))
else:
curl.setopt(pycurl.UPLOAD, True)
curl.setopt(pycurl.INFILESIZE, len(request.body or ""))
if request.auth_username is not None:
assert request.auth_password is not None
if request.auth_mode is None or request.auth_mode == "basic":
curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_BASIC)
elif request.auth_mode == "digest":
curl.setopt(pycurl.HTTPAUTH, pycurl.HTTPAUTH_DIGEST)
else:
raise ValueError("Unsupported auth_mode %s" % request.auth_mode)
userpwd = httputil.encode_username_password(
request.auth_username, request.auth_password
)
curl.setopt(pycurl.USERPWD, userpwd)
curl_log.debug(
"%s %s (username: %r)",
request.method,
request.url,
request.auth_username,
)
else:
curl.unsetopt(pycurl.USERPWD)
curl_log.debug("%s %s", request.method, request.url)
if request.client_cert is not None:
curl.setopt(pycurl.SSLCERT, request.client_cert)
if request.client_key is not None:
curl.setopt(pycurl.SSLKEY, request.client_key)
if request.ssl_options is not None:
raise ValueError("ssl_options not supported in curl_httpclient")
if threading.active_count() > 1:
# libcurl/pycurl is not thread-safe by default. When multiple threads
# are used, signals should be disabled. This has the side effect
# of disabling DNS timeouts in some environments (when libcurl is
# not linked against ares), so we don't do it when there is only one
# thread. Applications that use many short-lived threads may need
# to set NOSIGNAL manually in a prepare_curl_callback since
# there may not be any other threads running at the time we call
# threading.activeCount.
curl.setopt(pycurl.NOSIGNAL, 1)
if request.prepare_curl_callback is not None:
request.prepare_curl_callback(curl)
def _curl_header_callback(
self,
headers: httputil.HTTPHeaders,
header_callback: Callable[[str], None],
header_line_bytes: bytes,
) -> None:
header_line = native_str(header_line_bytes.decode("latin1"))
if header_callback is not None:
self.io_loop.add_callback(header_callback, header_line)
# header_line as returned by curl includes the end-of-line characters.
# whitespace at the start should be preserved to allow multi-line headers
header_line = header_line.rstrip()
if header_line.startswith("HTTP/"):
headers.clear()
try:
(__, __, reason) = httputil.parse_response_start_line(header_line)
header_line = "X-Http-Reason: %s" % reason
except httputil.HTTPInputError:
return
if not header_line:
return
headers.parse_line(header_line)
def _curl_debug(self, debug_type: int, debug_msg: str) -> None:
debug_types = ("I", "<", ">", "<", ">")
if debug_type == 0:
debug_msg = native_str(debug_msg)
curl_log.debug("%s", debug_msg.strip())
elif debug_type in (1, 2):
debug_msg = native_str(debug_msg)
for line in debug_msg.splitlines():
curl_log.debug("%s %s", debug_types[debug_type], line)
elif debug_type == 4:
curl_log.debug("%s %r", debug_types[debug_type], debug_msg)
class CurlError(HTTPError):
def __init__(self, errno: int, message: str) -> None:
HTTPError.__init__(self, 599, message)
self.errno = errno
if __name__ == "__main__":
AsyncHTTPClient.configure(CurlAsyncHTTPClient)
main()

View File

@@ -0,0 +1,400 @@
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Escaping/unescaping methods for HTML, JSON, URLs, and others.
Also includes a few other miscellaneous string manipulation functions that
have crept in over time.
"""
import html.entities
import json
import re
import urllib.parse
from tornado.util import unicode_type
import typing
from typing import Union, Any, Optional, Dict, List, Callable
_XHTML_ESCAPE_RE = re.compile("[&<>\"']")
_XHTML_ESCAPE_DICT = {
"&": "&amp;",
"<": "&lt;",
">": "&gt;",
'"': "&quot;",
"'": "&#39;",
}
def xhtml_escape(value: Union[str, bytes]) -> str:
"""Escapes a string so it is valid within HTML or XML.
Escapes the characters ``<``, ``>``, ``"``, ``'``, and ``&``.
When used in attribute values the escaped strings must be enclosed
in quotes.
.. versionchanged:: 3.2
Added the single quote to the list of escaped characters.
"""
return _XHTML_ESCAPE_RE.sub(
lambda match: _XHTML_ESCAPE_DICT[match.group(0)], to_basestring(value)
)
def xhtml_unescape(value: Union[str, bytes]) -> str:
"""Un-escapes an XML-escaped string."""
return re.sub(r"&(#?)(\w+?);", _convert_entity, _unicode(value))
# The fact that json_encode wraps json.dumps is an implementation detail.
# Please see https://github.com/tornadoweb/tornado/pull/706
# before sending a pull request that adds **kwargs to this function.
def json_encode(value: Any) -> str:
"""JSON-encodes the given Python object."""
# JSON permits but does not require forward slashes to be escaped.
# This is useful when json data is emitted in a <script> tag
# in HTML, as it prevents </script> tags from prematurely terminating
# the javascript. Some json libraries do this escaping by default,
# although python's standard library does not, so we do it here.
# http://stackoverflow.com/questions/1580647/json-why-are-forward-slashes-escaped
return json.dumps(value).replace("</", "<\\/")
def json_decode(value: Union[str, bytes]) -> Any:
"""Returns Python objects for the given JSON string.
Supports both `str` and `bytes` inputs.
"""
return json.loads(to_basestring(value))
def squeeze(value: str) -> str:
"""Replace all sequences of whitespace chars with a single space."""
return re.sub(r"[\x00-\x20]+", " ", value).strip()
def url_escape(value: Union[str, bytes], plus: bool = True) -> str:
"""Returns a URL-encoded version of the given value.
If ``plus`` is true (the default), spaces will be represented
as "+" instead of "%20". This is appropriate for query strings
but not for the path component of a URL. Note that this default
is the reverse of Python's urllib module.
.. versionadded:: 3.1
The ``plus`` argument
"""
quote = urllib.parse.quote_plus if plus else urllib.parse.quote
return quote(utf8(value))
@typing.overload
def url_unescape(value: Union[str, bytes], encoding: None, plus: bool = True) -> bytes:
pass
@typing.overload # noqa: F811
def url_unescape(
value: Union[str, bytes], encoding: str = "utf-8", plus: bool = True
) -> str:
pass
def url_unescape( # noqa: F811
value: Union[str, bytes], encoding: Optional[str] = "utf-8", plus: bool = True
) -> Union[str, bytes]:
"""Decodes the given value from a URL.
The argument may be either a byte or unicode string.
If encoding is None, the result will be a byte string. Otherwise,
the result is a unicode string in the specified encoding.
If ``plus`` is true (the default), plus signs will be interpreted
as spaces (literal plus signs must be represented as "%2B"). This
is appropriate for query strings and form-encoded values but not
for the path component of a URL. Note that this default is the
reverse of Python's urllib module.
.. versionadded:: 3.1
The ``plus`` argument
"""
if encoding is None:
if plus:
# unquote_to_bytes doesn't have a _plus variant
value = to_basestring(value).replace("+", " ")
return urllib.parse.unquote_to_bytes(value)
else:
unquote = urllib.parse.unquote_plus if plus else urllib.parse.unquote
return unquote(to_basestring(value), encoding=encoding)
def parse_qs_bytes(
qs: str, keep_blank_values: bool = False, strict_parsing: bool = False
) -> Dict[str, List[bytes]]:
"""Parses a query string like urlparse.parse_qs, but returns the
values as byte strings.
Keys still become type str (interpreted as latin1 in python3!)
because it's too painful to keep them as byte strings in
python3 and in practice they're nearly always ascii anyway.
"""
# This is gross, but python3 doesn't give us another way.
# Latin1 is the universal donor of character encodings.
result = urllib.parse.parse_qs(
qs, keep_blank_values, strict_parsing, encoding="latin1", errors="strict"
)
encoded = {}
for k, v in result.items():
encoded[k] = [i.encode("latin1") for i in v]
return encoded
_UTF8_TYPES = (bytes, type(None))
@typing.overload
def utf8(value: bytes) -> bytes:
pass
@typing.overload # noqa: F811
def utf8(value: str) -> bytes:
pass
@typing.overload # noqa: F811
def utf8(value: None) -> None:
pass
def utf8(value: Union[None, str, bytes]) -> Optional[bytes]: # noqa: F811
"""Converts a string argument to a byte string.
If the argument is already a byte string or None, it is returned unchanged.
Otherwise it must be a unicode string and is encoded as utf8.
"""
if isinstance(value, _UTF8_TYPES):
return value
if not isinstance(value, unicode_type):
raise TypeError("Expected bytes, unicode, or None; got %r" % type(value))
return value.encode("utf-8")
_TO_UNICODE_TYPES = (unicode_type, type(None))
@typing.overload
def to_unicode(value: str) -> str:
pass
@typing.overload # noqa: F811
def to_unicode(value: bytes) -> str:
pass
@typing.overload # noqa: F811
def to_unicode(value: None) -> None:
pass
def to_unicode(value: Union[None, str, bytes]) -> Optional[str]: # noqa: F811
"""Converts a string argument to a unicode string.
If the argument is already a unicode string or None, it is returned
unchanged. Otherwise it must be a byte string and is decoded as utf8.
"""
if isinstance(value, _TO_UNICODE_TYPES):
return value
if not isinstance(value, bytes):
raise TypeError("Expected bytes, unicode, or None; got %r" % type(value))
return value.decode("utf-8")
# to_unicode was previously named _unicode not because it was private,
# but to avoid conflicts with the built-in unicode() function/type
_unicode = to_unicode
# When dealing with the standard library across python 2 and 3 it is
# sometimes useful to have a direct conversion to the native string type
native_str = to_unicode
to_basestring = to_unicode
def recursive_unicode(obj: Any) -> Any:
"""Walks a simple data structure, converting byte strings to unicode.
Supports lists, tuples, and dictionaries.
"""
if isinstance(obj, dict):
return dict(
(recursive_unicode(k), recursive_unicode(v)) for (k, v) in obj.items()
)
elif isinstance(obj, list):
return list(recursive_unicode(i) for i in obj)
elif isinstance(obj, tuple):
return tuple(recursive_unicode(i) for i in obj)
elif isinstance(obj, bytes):
return to_unicode(obj)
else:
return obj
# I originally used the regex from
# http://daringfireball.net/2010/07/improved_regex_for_matching_urls
# but it gets all exponential on certain patterns (such as too many trailing
# dots), causing the regex matcher to never return.
# This regex should avoid those problems.
# Use to_unicode instead of tornado.util.u - we don't want backslashes getting
# processed as escapes.
_URL_RE = re.compile(
to_unicode(
r"""\b((?:([\w-]+):(/{1,3})|www[.])(?:(?:(?:[^\s&()]|&amp;|&quot;)*(?:[^!"#$%&'()*+,.:;<=>?@\[\]^`{|}~\s]))|(?:\((?:[^\s&()]|&amp;|&quot;)*\)))+)""" # noqa: E501
)
)
def linkify(
text: Union[str, bytes],
shorten: bool = False,
extra_params: Union[str, Callable[[str], str]] = "",
require_protocol: bool = False,
permitted_protocols: List[str] = ["http", "https"],
) -> str:
"""Converts plain text into HTML with links.
For example: ``linkify("Hello http://tornadoweb.org!")`` would return
``Hello <a href="http://tornadoweb.org">http://tornadoweb.org</a>!``
Parameters:
* ``shorten``: Long urls will be shortened for display.
* ``extra_params``: Extra text to include in the link tag, or a callable
taking the link as an argument and returning the extra text
e.g. ``linkify(text, extra_params='rel="nofollow" class="external"')``,
or::
def extra_params_cb(url):
if url.startswith("http://example.com"):
return 'class="internal"'
else:
return 'class="external" rel="nofollow"'
linkify(text, extra_params=extra_params_cb)
* ``require_protocol``: Only linkify urls which include a protocol. If
this is False, urls such as www.facebook.com will also be linkified.
* ``permitted_protocols``: List (or set) of protocols which should be
linkified, e.g. ``linkify(text, permitted_protocols=["http", "ftp",
"mailto"])``. It is very unsafe to include protocols such as
``javascript``.
"""
if extra_params and not callable(extra_params):
extra_params = " " + extra_params.strip()
def make_link(m: typing.Match) -> str:
url = m.group(1)
proto = m.group(2)
if require_protocol and not proto:
return url # not protocol, no linkify
if proto and proto not in permitted_protocols:
return url # bad protocol, no linkify
href = m.group(1)
if not proto:
href = "http://" + href # no proto specified, use http
if callable(extra_params):
params = " " + extra_params(href).strip()
else:
params = extra_params
# clip long urls. max_len is just an approximation
max_len = 30
if shorten and len(url) > max_len:
before_clip = url
if proto:
proto_len = len(proto) + 1 + len(m.group(3) or "") # +1 for :
else:
proto_len = 0
parts = url[proto_len:].split("/")
if len(parts) > 1:
# Grab the whole host part plus the first bit of the path
# The path is usually not that interesting once shortened
# (no more slug, etc), so it really just provides a little
# extra indication of shortening.
url = (
url[:proto_len]
+ parts[0]
+ "/"
+ parts[1][:8].split("?")[0].split(".")[0]
)
if len(url) > max_len * 1.5: # still too long
url = url[:max_len]
if url != before_clip:
amp = url.rfind("&")
# avoid splitting html char entities
if amp > max_len - 5:
url = url[:amp]
url += "..."
if len(url) >= len(before_clip):
url = before_clip
else:
# full url is visible on mouse-over (for those who don't
# have a status bar, such as Safari by default)
params += ' title="%s"' % href
return u'<a href="%s"%s>%s</a>' % (href, params, url)
# First HTML-escape so that our strings are all safe.
# The regex is modified to avoid character entites other than &amp; so
# that we won't pick up &quot;, etc.
text = _unicode(xhtml_escape(text))
return _URL_RE.sub(make_link, text)
def _convert_entity(m: typing.Match) -> str:
if m.group(1) == "#":
try:
if m.group(2)[:1].lower() == "x":
return chr(int(m.group(2)[1:], 16))
else:
return chr(int(m.group(2)))
except ValueError:
return "&#%s;" % m.group(2)
try:
return _HTML_UNICODE_MAP[m.group(2)]
except KeyError:
return "&%s;" % m.group(2)
def _build_unicode_map() -> Dict[str, str]:
unicode_map = {}
for name, value in html.entities.name2codepoint.items():
unicode_map[name] = chr(value)
return unicode_map
_HTML_UNICODE_MAP = _build_unicode_map()

View File

@@ -0,0 +1,845 @@
"""``tornado.gen`` implements generator-based coroutines.
.. note::
The "decorator and generator" approach in this module is a
precursor to native coroutines (using ``async def`` and ``await``)
which were introduced in Python 3.5. Applications that do not
require compatibility with older versions of Python should use
native coroutines instead. Some parts of this module are still
useful with native coroutines, notably `multi`, `sleep`,
`WaitIterator`, and `with_timeout`. Some of these functions have
counterparts in the `asyncio` module which may be used as well,
although the two may not necessarily be 100% compatible.
Coroutines provide an easier way to work in an asynchronous
environment than chaining callbacks. Code using coroutines is
technically asynchronous, but it is written as a single generator
instead of a collection of separate functions.
For example, here's a coroutine-based handler:
.. testcode::
class GenAsyncHandler(RequestHandler):
@gen.coroutine
def get(self):
http_client = AsyncHTTPClient()
response = yield http_client.fetch("http://example.com")
do_something_with_response(response)
self.render("template.html")
.. testoutput::
:hide:
Asynchronous functions in Tornado return an ``Awaitable`` or `.Future`;
yielding this object returns its result.
You can also yield a list or dict of other yieldable objects, which
will be started at the same time and run in parallel; a list or dict
of results will be returned when they are all finished:
.. testcode::
@gen.coroutine
def get(self):
http_client = AsyncHTTPClient()
response1, response2 = yield [http_client.fetch(url1),
http_client.fetch(url2)]
response_dict = yield dict(response3=http_client.fetch(url3),
response4=http_client.fetch(url4))
response3 = response_dict['response3']
response4 = response_dict['response4']
.. testoutput::
:hide:
If ``tornado.platform.twisted`` is imported, it is also possible to
yield Twisted's ``Deferred`` objects. See the `convert_yielded`
function to extend this mechanism.
.. versionchanged:: 3.2
Dict support added.
.. versionchanged:: 4.1
Support added for yielding ``asyncio`` Futures and Twisted Deferreds
via ``singledispatch``.
"""
import asyncio
import builtins
import collections
from collections.abc import Generator
import concurrent.futures
import datetime
import functools
from functools import singledispatch
from inspect import isawaitable
import sys
import types
from tornado.concurrent import (
Future,
is_future,
chain_future,
future_set_exc_info,
future_add_done_callback,
future_set_result_unless_cancelled,
)
from tornado.ioloop import IOLoop
from tornado.log import app_log
from tornado.util import TimeoutError
import typing
from typing import Union, Any, Callable, List, Type, Tuple, Awaitable, Dict
if typing.TYPE_CHECKING:
from typing import Sequence, Deque, Optional, Set, Iterable # noqa: F401
_T = typing.TypeVar("_T")
_Yieldable = Union[
None, Awaitable, List[Awaitable], Dict[Any, Awaitable], concurrent.futures.Future
]
class KeyReuseError(Exception):
pass
class UnknownKeyError(Exception):
pass
class LeakedCallbackError(Exception):
pass
class BadYieldError(Exception):
pass
class ReturnValueIgnoredError(Exception):
pass
def _value_from_stopiteration(e: Union[StopIteration, "Return"]) -> Any:
try:
# StopIteration has a value attribute beginning in py33.
# So does our Return class.
return e.value
except AttributeError:
pass
try:
# Cython backports coroutine functionality by putting the value in
# e.args[0].
return e.args[0]
except (AttributeError, IndexError):
return None
def _create_future() -> Future:
future = Future() # type: Future
# Fixup asyncio debug info by removing extraneous stack entries
source_traceback = getattr(future, "_source_traceback", ())
while source_traceback:
# Each traceback entry is equivalent to a
# (filename, self.lineno, self.name, self.line) tuple
filename = source_traceback[-1][0]
if filename == __file__:
del source_traceback[-1]
else:
break
return future
def coroutine(
func: Callable[..., "Generator[Any, Any, _T]"]
) -> Callable[..., "Future[_T]"]:
"""Decorator for asynchronous generators.
For compatibility with older versions of Python, coroutines may
also "return" by raising the special exception `Return(value)
<Return>`.
Functions with this decorator return a `.Future`.
.. warning::
When exceptions occur inside a coroutine, the exception
information will be stored in the `.Future` object. You must
examine the result of the `.Future` object, or the exception
may go unnoticed by your code. This means yielding the function
if called from another coroutine, using something like
`.IOLoop.run_sync` for top-level calls, or passing the `.Future`
to `.IOLoop.add_future`.
.. versionchanged:: 6.0
The ``callback`` argument was removed. Use the returned
awaitable object instead.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
# type: (*Any, **Any) -> Future[_T]
# This function is type-annotated with a comment to work around
# https://bitbucket.org/pypy/pypy/issues/2868/segfault-with-args-type-annotation-in
future = _create_future()
try:
result = func(*args, **kwargs)
except (Return, StopIteration) as e:
result = _value_from_stopiteration(e)
except Exception:
future_set_exc_info(future, sys.exc_info())
try:
return future
finally:
# Avoid circular references
future = None # type: ignore
else:
if isinstance(result, Generator):
# Inline the first iteration of Runner.run. This lets us
# avoid the cost of creating a Runner when the coroutine
# never actually yields, which in turn allows us to
# use "optional" coroutines in critical path code without
# performance penalty for the synchronous case.
try:
yielded = next(result)
except (StopIteration, Return) as e:
future_set_result_unless_cancelled(
future, _value_from_stopiteration(e)
)
except Exception:
future_set_exc_info(future, sys.exc_info())
else:
# Provide strong references to Runner objects as long
# as their result future objects also have strong
# references (typically from the parent coroutine's
# Runner). This keeps the coroutine's Runner alive.
# We do this by exploiting the public API
# add_done_callback() instead of putting a private
# attribute on the Future.
# (Github issues #1769, #2229).
runner = Runner(result, future, yielded)
future.add_done_callback(lambda _: runner)
yielded = None
try:
return future
finally:
# Subtle memory optimization: if next() raised an exception,
# the future's exc_info contains a traceback which
# includes this stack frame. This creates a cycle,
# which will be collected at the next full GC but has
# been shown to greatly increase memory usage of
# benchmarks (relative to the refcount-based scheme
# used in the absence of cycles). We can avoid the
# cycle by clearing the local variable after we return it.
future = None # type: ignore
future_set_result_unless_cancelled(future, result)
return future
wrapper.__wrapped__ = func # type: ignore
wrapper.__tornado_coroutine__ = True # type: ignore
return wrapper
def is_coroutine_function(func: Any) -> bool:
"""Return whether *func* is a coroutine function, i.e. a function
wrapped with `~.gen.coroutine`.
.. versionadded:: 4.5
"""
return getattr(func, "__tornado_coroutine__", False)
class Return(Exception):
"""Special exception to return a value from a `coroutine`.
If this exception is raised, its value argument is used as the
result of the coroutine::
@gen.coroutine
def fetch_json(url):
response = yield AsyncHTTPClient().fetch(url)
raise gen.Return(json_decode(response.body))
In Python 3.3, this exception is no longer necessary: the ``return``
statement can be used directly to return a value (previously
``yield`` and ``return`` with a value could not be combined in the
same function).
By analogy with the return statement, the value argument is optional,
but it is never necessary to ``raise gen.Return()``. The ``return``
statement can be used with no arguments instead.
"""
def __init__(self, value: Any = None) -> None:
super(Return, self).__init__()
self.value = value
# Cython recognizes subclasses of StopIteration with a .args tuple.
self.args = (value,)
class WaitIterator(object):
"""Provides an iterator to yield the results of awaitables as they finish.
Yielding a set of awaitables like this:
``results = yield [awaitable1, awaitable2]``
pauses the coroutine until both ``awaitable1`` and ``awaitable2``
return, and then restarts the coroutine with the results of both
awaitables. If either awaitable raises an exception, the
expression will raise that exception and all the results will be
lost.
If you need to get the result of each awaitable as soon as possible,
or if you need the result of some awaitables even if others produce
errors, you can use ``WaitIterator``::
wait_iterator = gen.WaitIterator(awaitable1, awaitable2)
while not wait_iterator.done():
try:
result = yield wait_iterator.next()
except Exception as e:
print("Error {} from {}".format(e, wait_iterator.current_future))
else:
print("Result {} received from {} at {}".format(
result, wait_iterator.current_future,
wait_iterator.current_index))
Because results are returned as soon as they are available the
output from the iterator *will not be in the same order as the
input arguments*. If you need to know which future produced the
current result, you can use the attributes
``WaitIterator.current_future``, or ``WaitIterator.current_index``
to get the index of the awaitable from the input list. (if keyword
arguments were used in the construction of the `WaitIterator`,
``current_index`` will use the corresponding keyword).
On Python 3.5, `WaitIterator` implements the async iterator
protocol, so it can be used with the ``async for`` statement (note
that in this version the entire iteration is aborted if any value
raises an exception, while the previous example can continue past
individual errors)::
async for result in gen.WaitIterator(future1, future2):
print("Result {} received from {} at {}".format(
result, wait_iterator.current_future,
wait_iterator.current_index))
.. versionadded:: 4.1
.. versionchanged:: 4.3
Added ``async for`` support in Python 3.5.
"""
_unfinished = {} # type: Dict[Future, Union[int, str]]
def __init__(self, *args: Future, **kwargs: Future) -> None:
if args and kwargs:
raise ValueError("You must provide args or kwargs, not both")
if kwargs:
self._unfinished = dict((f, k) for (k, f) in kwargs.items())
futures = list(kwargs.values()) # type: Sequence[Future]
else:
self._unfinished = dict((f, i) for (i, f) in enumerate(args))
futures = args
self._finished = collections.deque() # type: Deque[Future]
self.current_index = None # type: Optional[Union[str, int]]
self.current_future = None # type: Optional[Future]
self._running_future = None # type: Optional[Future]
for future in futures:
future_add_done_callback(future, self._done_callback)
def done(self) -> bool:
"""Returns True if this iterator has no more results."""
if self._finished or self._unfinished:
return False
# Clear the 'current' values when iteration is done.
self.current_index = self.current_future = None
return True
def next(self) -> Future:
"""Returns a `.Future` that will yield the next available result.
Note that this `.Future` will not be the same object as any of
the inputs.
"""
self._running_future = Future()
if self._finished:
self._return_result(self._finished.popleft())
return self._running_future
def _done_callback(self, done: Future) -> None:
if self._running_future and not self._running_future.done():
self._return_result(done)
else:
self._finished.append(done)
def _return_result(self, done: Future) -> None:
"""Called set the returned future's state that of the future
we yielded, and set the current future for the iterator.
"""
if self._running_future is None:
raise Exception("no future is running")
chain_future(done, self._running_future)
self.current_future = done
self.current_index = self._unfinished.pop(done)
def __aiter__(self) -> typing.AsyncIterator:
return self
def __anext__(self) -> Future:
if self.done():
# Lookup by name to silence pyflakes on older versions.
raise getattr(builtins, "StopAsyncIteration")()
return self.next()
def multi(
children: Union[List[_Yieldable], Dict[Any, _Yieldable]],
quiet_exceptions: "Union[Type[Exception], Tuple[Type[Exception], ...]]" = (),
) -> "Union[Future[List], Future[Dict]]":
"""Runs multiple asynchronous operations in parallel.
``children`` may either be a list or a dict whose values are
yieldable objects. ``multi()`` returns a new yieldable
object that resolves to a parallel structure containing their
results. If ``children`` is a list, the result is a list of
results in the same order; if it is a dict, the result is a dict
with the same keys.
That is, ``results = yield multi(list_of_futures)`` is equivalent
to::
results = []
for future in list_of_futures:
results.append(yield future)
If any children raise exceptions, ``multi()`` will raise the first
one. All others will be logged, unless they are of types
contained in the ``quiet_exceptions`` argument.
In a ``yield``-based coroutine, it is not normally necessary to
call this function directly, since the coroutine runner will
do it automatically when a list or dict is yielded. However,
it is necessary in ``await``-based coroutines, or to pass
the ``quiet_exceptions`` argument.
This function is available under the names ``multi()`` and ``Multi()``
for historical reasons.
Cancelling a `.Future` returned by ``multi()`` does not cancel its
children. `asyncio.gather` is similar to ``multi()``, but it does
cancel its children.
.. versionchanged:: 4.2
If multiple yieldables fail, any exceptions after the first
(which is raised) will be logged. Added the ``quiet_exceptions``
argument to suppress this logging for selected exception types.
.. versionchanged:: 4.3
Replaced the class ``Multi`` and the function ``multi_future``
with a unified function ``multi``. Added support for yieldables
other than ``YieldPoint`` and `.Future`.
"""
return multi_future(children, quiet_exceptions=quiet_exceptions)
Multi = multi
def multi_future(
children: Union[List[_Yieldable], Dict[Any, _Yieldable]],
quiet_exceptions: "Union[Type[Exception], Tuple[Type[Exception], ...]]" = (),
) -> "Union[Future[List], Future[Dict]]":
"""Wait for multiple asynchronous futures in parallel.
Since Tornado 6.0, this function is exactly the same as `multi`.
.. versionadded:: 4.0
.. versionchanged:: 4.2
If multiple ``Futures`` fail, any exceptions after the first (which is
raised) will be logged. Added the ``quiet_exceptions``
argument to suppress this logging for selected exception types.
.. deprecated:: 4.3
Use `multi` instead.
"""
if isinstance(children, dict):
keys = list(children.keys()) # type: Optional[List]
children_seq = children.values() # type: Iterable
else:
keys = None
children_seq = children
children_futs = list(map(convert_yielded, children_seq))
assert all(is_future(i) or isinstance(i, _NullFuture) for i in children_futs)
unfinished_children = set(children_futs)
future = _create_future()
if not children_futs:
future_set_result_unless_cancelled(future, {} if keys is not None else [])
def callback(fut: Future) -> None:
unfinished_children.remove(fut)
if not unfinished_children:
result_list = []
for f in children_futs:
try:
result_list.append(f.result())
except Exception as e:
if future.done():
if not isinstance(e, quiet_exceptions):
app_log.error(
"Multiple exceptions in yield list", exc_info=True
)
else:
future_set_exc_info(future, sys.exc_info())
if not future.done():
if keys is not None:
future_set_result_unless_cancelled(
future, dict(zip(keys, result_list))
)
else:
future_set_result_unless_cancelled(future, result_list)
listening = set() # type: Set[Future]
for f in children_futs:
if f not in listening:
listening.add(f)
future_add_done_callback(f, callback)
return future
def maybe_future(x: Any) -> Future:
"""Converts ``x`` into a `.Future`.
If ``x`` is already a `.Future`, it is simply returned; otherwise
it is wrapped in a new `.Future`. This is suitable for use as
``result = yield gen.maybe_future(f())`` when you don't know whether
``f()`` returns a `.Future` or not.
.. deprecated:: 4.3
This function only handles ``Futures``, not other yieldable objects.
Instead of `maybe_future`, check for the non-future result types
you expect (often just ``None``), and ``yield`` anything unknown.
"""
if is_future(x):
return x
else:
fut = _create_future()
fut.set_result(x)
return fut
def with_timeout(
timeout: Union[float, datetime.timedelta],
future: _Yieldable,
quiet_exceptions: "Union[Type[Exception], Tuple[Type[Exception], ...]]" = (),
) -> Future:
"""Wraps a `.Future` (or other yieldable object) in a timeout.
Raises `tornado.util.TimeoutError` if the input future does not
complete before ``timeout``, which may be specified in any form
allowed by `.IOLoop.add_timeout` (i.e. a `datetime.timedelta` or
an absolute time relative to `.IOLoop.time`)
If the wrapped `.Future` fails after it has timed out, the exception
will be logged unless it is either of a type contained in
``quiet_exceptions`` (which may be an exception type or a sequence of
types), or an ``asyncio.CancelledError``.
The wrapped `.Future` is not canceled when the timeout expires,
permitting it to be reused. `asyncio.wait_for` is similar to this
function but it does cancel the wrapped `.Future` on timeout.
.. versionadded:: 4.0
.. versionchanged:: 4.1
Added the ``quiet_exceptions`` argument and the logging of unhandled
exceptions.
.. versionchanged:: 4.4
Added support for yieldable objects other than `.Future`.
.. versionchanged:: 6.0.3
``asyncio.CancelledError`` is now always considered "quiet".
"""
# It's tempting to optimize this by cancelling the input future on timeout
# instead of creating a new one, but A) we can't know if we are the only
# one waiting on the input future, so cancelling it might disrupt other
# callers and B) concurrent futures can only be cancelled while they are
# in the queue, so cancellation cannot reliably bound our waiting time.
future_converted = convert_yielded(future)
result = _create_future()
chain_future(future_converted, result)
io_loop = IOLoop.current()
def error_callback(future: Future) -> None:
try:
future.result()
except asyncio.CancelledError:
pass
except Exception as e:
if not isinstance(e, quiet_exceptions):
app_log.error(
"Exception in Future %r after timeout", future, exc_info=True
)
def timeout_callback() -> None:
if not result.done():
result.set_exception(TimeoutError("Timeout"))
# In case the wrapped future goes on to fail, log it.
future_add_done_callback(future_converted, error_callback)
timeout_handle = io_loop.add_timeout(timeout, timeout_callback)
if isinstance(future_converted, Future):
# We know this future will resolve on the IOLoop, so we don't
# need the extra thread-safety of IOLoop.add_future (and we also
# don't care about StackContext here.
future_add_done_callback(
future_converted, lambda future: io_loop.remove_timeout(timeout_handle)
)
else:
# concurrent.futures.Futures may resolve on any thread, so we
# need to route them back to the IOLoop.
io_loop.add_future(
future_converted, lambda future: io_loop.remove_timeout(timeout_handle)
)
return result
def sleep(duration: float) -> "Future[None]":
"""Return a `.Future` that resolves after the given number of seconds.
When used with ``yield`` in a coroutine, this is a non-blocking
analogue to `time.sleep` (which should not be used in coroutines
because it is blocking)::
yield gen.sleep(0.5)
Note that calling this function on its own does nothing; you must
wait on the `.Future` it returns (usually by yielding it).
.. versionadded:: 4.1
"""
f = _create_future()
IOLoop.current().call_later(
duration, lambda: future_set_result_unless_cancelled(f, None)
)
return f
class _NullFuture(object):
"""_NullFuture resembles a Future that finished with a result of None.
It's not actually a `Future` to avoid depending on a particular event loop.
Handled as a special case in the coroutine runner.
We lie and tell the type checker that a _NullFuture is a Future so
we don't have to leak _NullFuture into lots of public APIs. But
this means that the type checker can't warn us when we're passing
a _NullFuture into a code path that doesn't understand what to do
with it.
"""
def result(self) -> None:
return None
def done(self) -> bool:
return True
# _null_future is used as a dummy value in the coroutine runner. It differs
# from moment in that moment always adds a delay of one IOLoop iteration
# while _null_future is processed as soon as possible.
_null_future = typing.cast(Future, _NullFuture())
moment = typing.cast(Future, _NullFuture())
moment.__doc__ = """A special object which may be yielded to allow the IOLoop to run for
one iteration.
This is not needed in normal use but it can be helpful in long-running
coroutines that are likely to yield Futures that are ready instantly.
Usage: ``yield gen.moment``
In native coroutines, the equivalent of ``yield gen.moment`` is
``await asyncio.sleep(0)``.
.. versionadded:: 4.0
.. deprecated:: 4.5
``yield None`` (or ``yield`` with no argument) is now equivalent to
``yield gen.moment``.
"""
class Runner(object):
"""Internal implementation of `tornado.gen.coroutine`.
Maintains information about pending callbacks and their results.
The results of the generator are stored in ``result_future`` (a
`.Future`)
"""
def __init__(
self,
gen: "Generator[_Yieldable, Any, _T]",
result_future: "Future[_T]",
first_yielded: _Yieldable,
) -> None:
self.gen = gen
self.result_future = result_future
self.future = _null_future # type: Union[None, Future]
self.running = False
self.finished = False
self.io_loop = IOLoop.current()
if self.handle_yield(first_yielded):
gen = result_future = first_yielded = None # type: ignore
self.run()
def run(self) -> None:
"""Starts or resumes the generator, running until it reaches a
yield point that is not ready.
"""
if self.running or self.finished:
return
try:
self.running = True
while True:
future = self.future
if future is None:
raise Exception("No pending future")
if not future.done():
return
self.future = None
try:
exc_info = None
try:
value = future.result()
except Exception:
exc_info = sys.exc_info()
future = None
if exc_info is not None:
try:
yielded = self.gen.throw(*exc_info) # type: ignore
finally:
# Break up a reference to itself
# for faster GC on CPython.
exc_info = None
else:
yielded = self.gen.send(value)
except (StopIteration, Return) as e:
self.finished = True
self.future = _null_future
future_set_result_unless_cancelled(
self.result_future, _value_from_stopiteration(e)
)
self.result_future = None # type: ignore
return
except Exception:
self.finished = True
self.future = _null_future
future_set_exc_info(self.result_future, sys.exc_info())
self.result_future = None # type: ignore
return
if not self.handle_yield(yielded):
return
yielded = None
finally:
self.running = False
def handle_yield(self, yielded: _Yieldable) -> bool:
try:
self.future = convert_yielded(yielded)
except BadYieldError:
self.future = Future()
future_set_exc_info(self.future, sys.exc_info())
if self.future is moment:
self.io_loop.add_callback(self.run)
return False
elif self.future is None:
raise Exception("no pending future")
elif not self.future.done():
def inner(f: Any) -> None:
# Break a reference cycle to speed GC.
f = None # noqa: F841
self.run()
self.io_loop.add_future(self.future, inner)
return False
return True
def handle_exception(
self, typ: Type[Exception], value: Exception, tb: types.TracebackType
) -> bool:
if not self.running and not self.finished:
self.future = Future()
future_set_exc_info(self.future, (typ, value, tb))
self.run()
return True
else:
return False
# Convert Awaitables into Futures.
try:
_wrap_awaitable = asyncio.ensure_future
except AttributeError:
# asyncio.ensure_future was introduced in Python 3.4.4, but
# Debian jessie still ships with 3.4.2 so try the old name.
_wrap_awaitable = getattr(asyncio, "async")
def convert_yielded(yielded: _Yieldable) -> Future:
"""Convert a yielded object into a `.Future`.
The default implementation accepts lists, dictionaries, and
Futures. This has the side effect of starting any coroutines that
did not start themselves, similar to `asyncio.ensure_future`.
If the `~functools.singledispatch` library is available, this function
may be extended to support additional types. For example::
@convert_yielded.register(asyncio.Future)
def _(asyncio_future):
return tornado.platform.asyncio.to_tornado_future(asyncio_future)
.. versionadded:: 4.1
"""
if yielded is None or yielded is moment:
return moment
elif yielded is _null_future:
return _null_future
elif isinstance(yielded, (list, dict)):
return multi(yielded) # type: ignore
elif is_future(yielded):
return typing.cast(Future, yielded)
elif isawaitable(yielded):
return _wrap_awaitable(yielded) # type: ignore
else:
raise BadYieldError("yielded unknown object %r" % (yielded,))
convert_yielded = singledispatch(convert_yielded)

View File

@@ -0,0 +1,836 @@
#
# Copyright 2014 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Client and server implementations of HTTP/1.x.
.. versionadded:: 4.0
"""
import asyncio
import logging
import re
import types
from tornado.concurrent import (
Future,
future_add_done_callback,
future_set_result_unless_cancelled,
)
from tornado.escape import native_str, utf8
from tornado import gen
from tornado import httputil
from tornado import iostream
from tornado.log import gen_log, app_log
from tornado.util import GzipDecompressor
from typing import cast, Optional, Type, Awaitable, Callable, Union, Tuple
class _QuietException(Exception):
def __init__(self) -> None:
pass
class _ExceptionLoggingContext(object):
"""Used with the ``with`` statement when calling delegate methods to
log any exceptions with the given logger. Any exceptions caught are
converted to _QuietException
"""
def __init__(self, logger: logging.Logger) -> None:
self.logger = logger
def __enter__(self) -> None:
pass
def __exit__(
self,
typ: "Optional[Type[BaseException]]",
value: Optional[BaseException],
tb: types.TracebackType,
) -> None:
if value is not None:
assert typ is not None
self.logger.error("Uncaught exception", exc_info=(typ, value, tb))
raise _QuietException
class HTTP1ConnectionParameters(object):
"""Parameters for `.HTTP1Connection` and `.HTTP1ServerConnection`.
"""
def __init__(
self,
no_keep_alive: bool = False,
chunk_size: int = None,
max_header_size: int = None,
header_timeout: float = None,
max_body_size: int = None,
body_timeout: float = None,
decompress: bool = False,
) -> None:
"""
:arg bool no_keep_alive: If true, always close the connection after
one request.
:arg int chunk_size: how much data to read into memory at once
:arg int max_header_size: maximum amount of data for HTTP headers
:arg float header_timeout: how long to wait for all headers (seconds)
:arg int max_body_size: maximum amount of data for body
:arg float body_timeout: how long to wait while reading body (seconds)
:arg bool decompress: if true, decode incoming
``Content-Encoding: gzip``
"""
self.no_keep_alive = no_keep_alive
self.chunk_size = chunk_size or 65536
self.max_header_size = max_header_size or 65536
self.header_timeout = header_timeout
self.max_body_size = max_body_size
self.body_timeout = body_timeout
self.decompress = decompress
class HTTP1Connection(httputil.HTTPConnection):
"""Implements the HTTP/1.x protocol.
This class can be on its own for clients, or via `HTTP1ServerConnection`
for servers.
"""
def __init__(
self,
stream: iostream.IOStream,
is_client: bool,
params: HTTP1ConnectionParameters = None,
context: object = None,
) -> None:
"""
:arg stream: an `.IOStream`
:arg bool is_client: client or server
:arg params: a `.HTTP1ConnectionParameters` instance or ``None``
:arg context: an opaque application-defined object that can be accessed
as ``connection.context``.
"""
self.is_client = is_client
self.stream = stream
if params is None:
params = HTTP1ConnectionParameters()
self.params = params
self.context = context
self.no_keep_alive = params.no_keep_alive
# The body limits can be altered by the delegate, so save them
# here instead of just referencing self.params later.
self._max_body_size = self.params.max_body_size or self.stream.max_buffer_size
self._body_timeout = self.params.body_timeout
# _write_finished is set to True when finish() has been called,
# i.e. there will be no more data sent. Data may still be in the
# stream's write buffer.
self._write_finished = False
# True when we have read the entire incoming body.
self._read_finished = False
# _finish_future resolves when all data has been written and flushed
# to the IOStream.
self._finish_future = Future() # type: Future[None]
# If true, the connection should be closed after this request
# (after the response has been written in the server side,
# and after it has been read in the client)
self._disconnect_on_finish = False
self._clear_callbacks()
# Save the start lines after we read or write them; they
# affect later processing (e.g. 304 responses and HEAD methods
# have content-length but no bodies)
self._request_start_line = None # type: Optional[httputil.RequestStartLine]
self._response_start_line = None # type: Optional[httputil.ResponseStartLine]
self._request_headers = None # type: Optional[httputil.HTTPHeaders]
# True if we are writing output with chunked encoding.
self._chunking_output = False
# While reading a body with a content-length, this is the
# amount left to read.
self._expected_content_remaining = None # type: Optional[int]
# A Future for our outgoing writes, returned by IOStream.write.
self._pending_write = None # type: Optional[Future[None]]
def read_response(self, delegate: httputil.HTTPMessageDelegate) -> Awaitable[bool]:
"""Read a single HTTP response.
Typical client-mode usage is to write a request using `write_headers`,
`write`, and `finish`, and then call ``read_response``.
:arg delegate: a `.HTTPMessageDelegate`
Returns a `.Future` that resolves to a bool after the full response has
been read. The result is true if the stream is still open.
"""
if self.params.decompress:
delegate = _GzipMessageDelegate(delegate, self.params.chunk_size)
return self._read_message(delegate)
async def _read_message(self, delegate: httputil.HTTPMessageDelegate) -> bool:
need_delegate_close = False
try:
header_future = self.stream.read_until_regex(
b"\r?\n\r?\n", max_bytes=self.params.max_header_size
)
if self.params.header_timeout is None:
header_data = await header_future
else:
try:
header_data = await gen.with_timeout(
self.stream.io_loop.time() + self.params.header_timeout,
header_future,
quiet_exceptions=iostream.StreamClosedError,
)
except gen.TimeoutError:
self.close()
return False
start_line_str, headers = self._parse_headers(header_data)
if self.is_client:
resp_start_line = httputil.parse_response_start_line(start_line_str)
self._response_start_line = resp_start_line
start_line = (
resp_start_line
) # type: Union[httputil.RequestStartLine, httputil.ResponseStartLine]
# TODO: this will need to change to support client-side keepalive
self._disconnect_on_finish = False
else:
req_start_line = httputil.parse_request_start_line(start_line_str)
self._request_start_line = req_start_line
self._request_headers = headers
start_line = req_start_line
self._disconnect_on_finish = not self._can_keep_alive(
req_start_line, headers
)
need_delegate_close = True
with _ExceptionLoggingContext(app_log):
header_recv_future = delegate.headers_received(start_line, headers)
if header_recv_future is not None:
await header_recv_future
if self.stream is None:
# We've been detached.
need_delegate_close = False
return False
skip_body = False
if self.is_client:
assert isinstance(start_line, httputil.ResponseStartLine)
if (
self._request_start_line is not None
and self._request_start_line.method == "HEAD"
):
skip_body = True
code = start_line.code
if code == 304:
# 304 responses may include the content-length header
# but do not actually have a body.
# http://tools.ietf.org/html/rfc7230#section-3.3
skip_body = True
if code >= 100 and code < 200:
# 1xx responses should never indicate the presence of
# a body.
if "Content-Length" in headers or "Transfer-Encoding" in headers:
raise httputil.HTTPInputError(
"Response code %d cannot have body" % code
)
# TODO: client delegates will get headers_received twice
# in the case of a 100-continue. Document or change?
await self._read_message(delegate)
else:
if headers.get("Expect") == "100-continue" and not self._write_finished:
self.stream.write(b"HTTP/1.1 100 (Continue)\r\n\r\n")
if not skip_body:
body_future = self._read_body(
resp_start_line.code if self.is_client else 0, headers, delegate
)
if body_future is not None:
if self._body_timeout is None:
await body_future
else:
try:
await gen.with_timeout(
self.stream.io_loop.time() + self._body_timeout,
body_future,
quiet_exceptions=iostream.StreamClosedError,
)
except gen.TimeoutError:
gen_log.info("Timeout reading body from %s", self.context)
self.stream.close()
return False
self._read_finished = True
if not self._write_finished or self.is_client:
need_delegate_close = False
with _ExceptionLoggingContext(app_log):
delegate.finish()
# If we're waiting for the application to produce an asynchronous
# response, and we're not detached, register a close callback
# on the stream (we didn't need one while we were reading)
if (
not self._finish_future.done()
and self.stream is not None
and not self.stream.closed()
):
self.stream.set_close_callback(self._on_connection_close)
await self._finish_future
if self.is_client and self._disconnect_on_finish:
self.close()
if self.stream is None:
return False
except httputil.HTTPInputError as e:
gen_log.info("Malformed HTTP message from %s: %s", self.context, e)
if not self.is_client:
await self.stream.write(b"HTTP/1.1 400 Bad Request\r\n\r\n")
self.close()
return False
finally:
if need_delegate_close:
with _ExceptionLoggingContext(app_log):
delegate.on_connection_close()
header_future = None # type: ignore
self._clear_callbacks()
return True
def _clear_callbacks(self) -> None:
"""Clears the callback attributes.
This allows the request handler to be garbage collected more
quickly in CPython by breaking up reference cycles.
"""
self._write_callback = None
self._write_future = None # type: Optional[Future[None]]
self._close_callback = None # type: Optional[Callable[[], None]]
if self.stream is not None:
self.stream.set_close_callback(None)
def set_close_callback(self, callback: Optional[Callable[[], None]]) -> None:
"""Sets a callback that will be run when the connection is closed.
Note that this callback is slightly different from
`.HTTPMessageDelegate.on_connection_close`: The
`.HTTPMessageDelegate` method is called when the connection is
closed while recieving a message. This callback is used when
there is not an active delegate (for example, on the server
side this callback is used if the client closes the connection
after sending its request but before receiving all the
response.
"""
self._close_callback = callback
def _on_connection_close(self) -> None:
# Note that this callback is only registered on the IOStream
# when we have finished reading the request and are waiting for
# the application to produce its response.
if self._close_callback is not None:
callback = self._close_callback
self._close_callback = None
callback()
if not self._finish_future.done():
future_set_result_unless_cancelled(self._finish_future, None)
self._clear_callbacks()
def close(self) -> None:
if self.stream is not None:
self.stream.close()
self._clear_callbacks()
if not self._finish_future.done():
future_set_result_unless_cancelled(self._finish_future, None)
def detach(self) -> iostream.IOStream:
"""Take control of the underlying stream.
Returns the underlying `.IOStream` object and stops all further
HTTP processing. May only be called during
`.HTTPMessageDelegate.headers_received`. Intended for implementing
protocols like websockets that tunnel over an HTTP handshake.
"""
self._clear_callbacks()
stream = self.stream
self.stream = None # type: ignore
if not self._finish_future.done():
future_set_result_unless_cancelled(self._finish_future, None)
return stream
def set_body_timeout(self, timeout: float) -> None:
"""Sets the body timeout for a single request.
Overrides the value from `.HTTP1ConnectionParameters`.
"""
self._body_timeout = timeout
def set_max_body_size(self, max_body_size: int) -> None:
"""Sets the body size limit for a single request.
Overrides the value from `.HTTP1ConnectionParameters`.
"""
self._max_body_size = max_body_size
def write_headers(
self,
start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine],
headers: httputil.HTTPHeaders,
chunk: bytes = None,
) -> "Future[None]":
"""Implements `.HTTPConnection.write_headers`."""
lines = []
if self.is_client:
assert isinstance(start_line, httputil.RequestStartLine)
self._request_start_line = start_line
lines.append(utf8("%s %s HTTP/1.1" % (start_line[0], start_line[1])))
# Client requests with a non-empty body must have either a
# Content-Length or a Transfer-Encoding.
self._chunking_output = (
start_line.method in ("POST", "PUT", "PATCH")
and "Content-Length" not in headers
and (
"Transfer-Encoding" not in headers
or headers["Transfer-Encoding"] == "chunked"
)
)
else:
assert isinstance(start_line, httputil.ResponseStartLine)
assert self._request_start_line is not None
assert self._request_headers is not None
self._response_start_line = start_line
lines.append(utf8("HTTP/1.1 %d %s" % (start_line[1], start_line[2])))
self._chunking_output = (
# TODO: should this use
# self._request_start_line.version or
# start_line.version?
self._request_start_line.version == "HTTP/1.1"
# 1xx, 204 and 304 responses have no body (not even a zero-length
# body), and so should not have either Content-Length or
# Transfer-Encoding headers.
and start_line.code not in (204, 304)
and (start_line.code < 100 or start_line.code >= 200)
# No need to chunk the output if a Content-Length is specified.
and "Content-Length" not in headers
# Applications are discouraged from touching Transfer-Encoding,
# but if they do, leave it alone.
and "Transfer-Encoding" not in headers
)
# If connection to a 1.1 client will be closed, inform client
if (
self._request_start_line.version == "HTTP/1.1"
and self._disconnect_on_finish
):
headers["Connection"] = "close"
# If a 1.0 client asked for keep-alive, add the header.
if (
self._request_start_line.version == "HTTP/1.0"
and self._request_headers.get("Connection", "").lower() == "keep-alive"
):
headers["Connection"] = "Keep-Alive"
if self._chunking_output:
headers["Transfer-Encoding"] = "chunked"
if not self.is_client and (
self._request_start_line.method == "HEAD"
or cast(httputil.ResponseStartLine, start_line).code == 304
):
self._expected_content_remaining = 0
elif "Content-Length" in headers:
self._expected_content_remaining = int(headers["Content-Length"])
else:
self._expected_content_remaining = None
# TODO: headers are supposed to be of type str, but we still have some
# cases that let bytes slip through. Remove these native_str calls when those
# are fixed.
header_lines = (
native_str(n) + ": " + native_str(v) for n, v in headers.get_all()
)
lines.extend(l.encode("latin1") for l in header_lines)
for line in lines:
if b"\n" in line:
raise ValueError("Newline in header: " + repr(line))
future = None
if self.stream.closed():
future = self._write_future = Future()
future.set_exception(iostream.StreamClosedError())
future.exception()
else:
future = self._write_future = Future()
data = b"\r\n".join(lines) + b"\r\n\r\n"
if chunk:
data += self._format_chunk(chunk)
self._pending_write = self.stream.write(data)
future_add_done_callback(self._pending_write, self._on_write_complete)
return future
def _format_chunk(self, chunk: bytes) -> bytes:
if self._expected_content_remaining is not None:
self._expected_content_remaining -= len(chunk)
if self._expected_content_remaining < 0:
# Close the stream now to stop further framing errors.
self.stream.close()
raise httputil.HTTPOutputError(
"Tried to write more data than Content-Length"
)
if self._chunking_output and chunk:
# Don't write out empty chunks because that means END-OF-STREAM
# with chunked encoding
return utf8("%x" % len(chunk)) + b"\r\n" + chunk + b"\r\n"
else:
return chunk
def write(self, chunk: bytes) -> "Future[None]":
"""Implements `.HTTPConnection.write`.
For backwards compatibility it is allowed but deprecated to
skip `write_headers` and instead call `write()` with a
pre-encoded header block.
"""
future = None
if self.stream.closed():
future = self._write_future = Future()
self._write_future.set_exception(iostream.StreamClosedError())
self._write_future.exception()
else:
future = self._write_future = Future()
self._pending_write = self.stream.write(self._format_chunk(chunk))
future_add_done_callback(self._pending_write, self._on_write_complete)
return future
def finish(self) -> None:
"""Implements `.HTTPConnection.finish`."""
if (
self._expected_content_remaining is not None
and self._expected_content_remaining != 0
and not self.stream.closed()
):
self.stream.close()
raise httputil.HTTPOutputError(
"Tried to write %d bytes less than Content-Length"
% self._expected_content_remaining
)
if self._chunking_output:
if not self.stream.closed():
self._pending_write = self.stream.write(b"0\r\n\r\n")
self._pending_write.add_done_callback(self._on_write_complete)
self._write_finished = True
# If the app finished the request while we're still reading,
# divert any remaining data away from the delegate and
# close the connection when we're done sending our response.
# Closing the connection is the only way to avoid reading the
# whole input body.
if not self._read_finished:
self._disconnect_on_finish = True
# No more data is coming, so instruct TCP to send any remaining
# data immediately instead of waiting for a full packet or ack.
self.stream.set_nodelay(True)
if self._pending_write is None:
self._finish_request(None)
else:
future_add_done_callback(self._pending_write, self._finish_request)
def _on_write_complete(self, future: "Future[None]") -> None:
exc = future.exception()
if exc is not None and not isinstance(exc, iostream.StreamClosedError):
future.result()
if self._write_callback is not None:
callback = self._write_callback
self._write_callback = None
self.stream.io_loop.add_callback(callback)
if self._write_future is not None:
future = self._write_future
self._write_future = None
future_set_result_unless_cancelled(future, None)
def _can_keep_alive(
self, start_line: httputil.RequestStartLine, headers: httputil.HTTPHeaders
) -> bool:
if self.params.no_keep_alive:
return False
connection_header = headers.get("Connection")
if connection_header is not None:
connection_header = connection_header.lower()
if start_line.version == "HTTP/1.1":
return connection_header != "close"
elif (
"Content-Length" in headers
or headers.get("Transfer-Encoding", "").lower() == "chunked"
or getattr(start_line, "method", None) in ("HEAD", "GET")
):
# start_line may be a request or response start line; only
# the former has a method attribute.
return connection_header == "keep-alive"
return False
def _finish_request(self, future: "Optional[Future[None]]") -> None:
self._clear_callbacks()
if not self.is_client and self._disconnect_on_finish:
self.close()
return
# Turn Nagle's algorithm back on, leaving the stream in its
# default state for the next request.
self.stream.set_nodelay(False)
if not self._finish_future.done():
future_set_result_unless_cancelled(self._finish_future, None)
def _parse_headers(self, data: bytes) -> Tuple[str, httputil.HTTPHeaders]:
# The lstrip removes newlines that some implementations sometimes
# insert between messages of a reused connection. Per RFC 7230,
# we SHOULD ignore at least one empty line before the request.
# http://tools.ietf.org/html/rfc7230#section-3.5
data_str = native_str(data.decode("latin1")).lstrip("\r\n")
# RFC 7230 section allows for both CRLF and bare LF.
eol = data_str.find("\n")
start_line = data_str[:eol].rstrip("\r")
headers = httputil.HTTPHeaders.parse(data_str[eol:])
return start_line, headers
def _read_body(
self,
code: int,
headers: httputil.HTTPHeaders,
delegate: httputil.HTTPMessageDelegate,
) -> Optional[Awaitable[None]]:
if "Content-Length" in headers:
if "Transfer-Encoding" in headers:
# Response cannot contain both Content-Length and
# Transfer-Encoding headers.
# http://tools.ietf.org/html/rfc7230#section-3.3.3
raise httputil.HTTPInputError(
"Response with both Transfer-Encoding and Content-Length"
)
if "," in headers["Content-Length"]:
# Proxies sometimes cause Content-Length headers to get
# duplicated. If all the values are identical then we can
# use them but if they differ it's an error.
pieces = re.split(r",\s*", headers["Content-Length"])
if any(i != pieces[0] for i in pieces):
raise httputil.HTTPInputError(
"Multiple unequal Content-Lengths: %r"
% headers["Content-Length"]
)
headers["Content-Length"] = pieces[0]
try:
content_length = int(headers["Content-Length"]) # type: Optional[int]
except ValueError:
# Handles non-integer Content-Length value.
raise httputil.HTTPInputError(
"Only integer Content-Length is allowed: %s"
% headers["Content-Length"]
)
if cast(int, content_length) > self._max_body_size:
raise httputil.HTTPInputError("Content-Length too long")
else:
content_length = None
if code == 204:
# This response code is not allowed to have a non-empty body,
# and has an implicit length of zero instead of read-until-close.
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.3
if "Transfer-Encoding" in headers or content_length not in (None, 0):
raise httputil.HTTPInputError(
"Response with code %d should not have body" % code
)
content_length = 0
if content_length is not None:
return self._read_fixed_body(content_length, delegate)
if headers.get("Transfer-Encoding", "").lower() == "chunked":
return self._read_chunked_body(delegate)
if self.is_client:
return self._read_body_until_close(delegate)
return None
async def _read_fixed_body(
self, content_length: int, delegate: httputil.HTTPMessageDelegate
) -> None:
while content_length > 0:
body = await self.stream.read_bytes(
min(self.params.chunk_size, content_length), partial=True
)
content_length -= len(body)
if not self._write_finished or self.is_client:
with _ExceptionLoggingContext(app_log):
ret = delegate.data_received(body)
if ret is not None:
await ret
async def _read_chunked_body(self, delegate: httputil.HTTPMessageDelegate) -> None:
# TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1
total_size = 0
while True:
chunk_len_str = await self.stream.read_until(b"\r\n", max_bytes=64)
chunk_len = int(chunk_len_str.strip(), 16)
if chunk_len == 0:
crlf = await self.stream.read_bytes(2)
if crlf != b"\r\n":
raise httputil.HTTPInputError(
"improperly terminated chunked request"
)
return
total_size += chunk_len
if total_size > self._max_body_size:
raise httputil.HTTPInputError("chunked body too large")
bytes_to_read = chunk_len
while bytes_to_read:
chunk = await self.stream.read_bytes(
min(bytes_to_read, self.params.chunk_size), partial=True
)
bytes_to_read -= len(chunk)
if not self._write_finished or self.is_client:
with _ExceptionLoggingContext(app_log):
ret = delegate.data_received(chunk)
if ret is not None:
await ret
# chunk ends with \r\n
crlf = await self.stream.read_bytes(2)
assert crlf == b"\r\n"
async def _read_body_until_close(
self, delegate: httputil.HTTPMessageDelegate
) -> None:
body = await self.stream.read_until_close()
if not self._write_finished or self.is_client:
with _ExceptionLoggingContext(app_log):
ret = delegate.data_received(body)
if ret is not None:
await ret
class _GzipMessageDelegate(httputil.HTTPMessageDelegate):
"""Wraps an `HTTPMessageDelegate` to decode ``Content-Encoding: gzip``.
"""
def __init__(self, delegate: httputil.HTTPMessageDelegate, chunk_size: int) -> None:
self._delegate = delegate
self._chunk_size = chunk_size
self._decompressor = None # type: Optional[GzipDecompressor]
def headers_received(
self,
start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine],
headers: httputil.HTTPHeaders,
) -> Optional[Awaitable[None]]:
if headers.get("Content-Encoding") == "gzip":
self._decompressor = GzipDecompressor()
# Downstream delegates will only see uncompressed data,
# so rename the content-encoding header.
# (but note that curl_httpclient doesn't do this).
headers.add("X-Consumed-Content-Encoding", headers["Content-Encoding"])
del headers["Content-Encoding"]
return self._delegate.headers_received(start_line, headers)
async def data_received(self, chunk: bytes) -> None:
if self._decompressor:
compressed_data = chunk
while compressed_data:
decompressed = self._decompressor.decompress(
compressed_data, self._chunk_size
)
if decompressed:
ret = self._delegate.data_received(decompressed)
if ret is not None:
await ret
compressed_data = self._decompressor.unconsumed_tail
else:
ret = self._delegate.data_received(chunk)
if ret is not None:
await ret
def finish(self) -> None:
if self._decompressor is not None:
tail = self._decompressor.flush()
if tail:
# The tail should always be empty: decompress returned
# all that it can in data_received and the only
# purpose of the flush call is to detect errors such
# as truncated input. If we did legitimately get a new
# chunk at this point we'd need to change the
# interface to make finish() a coroutine.
raise ValueError(
"decompressor.flush returned data; possile truncated input"
)
return self._delegate.finish()
def on_connection_close(self) -> None:
return self._delegate.on_connection_close()
class HTTP1ServerConnection(object):
"""An HTTP/1.x server."""
def __init__(
self,
stream: iostream.IOStream,
params: HTTP1ConnectionParameters = None,
context: object = None,
) -> None:
"""
:arg stream: an `.IOStream`
:arg params: a `.HTTP1ConnectionParameters` or None
:arg context: an opaque application-defined object that is accessible
as ``connection.context``
"""
self.stream = stream
if params is None:
params = HTTP1ConnectionParameters()
self.params = params
self.context = context
self._serving_future = None # type: Optional[Future[None]]
async def close(self) -> None:
"""Closes the connection.
Returns a `.Future` that resolves after the serving loop has exited.
"""
self.stream.close()
# Block until the serving loop is done, but ignore any exceptions
# (start_serving is already responsible for logging them).
assert self._serving_future is not None
try:
await self._serving_future
except Exception:
pass
def start_serving(self, delegate: httputil.HTTPServerConnectionDelegate) -> None:
"""Starts serving requests on this connection.
:arg delegate: a `.HTTPServerConnectionDelegate`
"""
assert isinstance(delegate, httputil.HTTPServerConnectionDelegate)
fut = gen.convert_yielded(self._server_request_loop(delegate))
self._serving_future = fut
# Register the future on the IOLoop so its errors get logged.
self.stream.io_loop.add_future(fut, lambda f: f.result())
async def _server_request_loop(
self, delegate: httputil.HTTPServerConnectionDelegate
) -> None:
try:
while True:
conn = HTTP1Connection(self.stream, False, self.params, self.context)
request_delegate = delegate.start_request(self, conn)
try:
ret = await conn.read_response(request_delegate)
except (
iostream.StreamClosedError,
iostream.UnsatisfiableReadError,
asyncio.CancelledError,
):
return
except _QuietException:
# This exception was already logged.
conn.close()
return
except Exception:
gen_log.error("Uncaught exception", exc_info=True)
conn.close()
return
if not ret:
return
await asyncio.sleep(0)
finally:
delegate.on_close(self)

View File

@@ -0,0 +1,781 @@
"""Blocking and non-blocking HTTP client interfaces.
This module defines a common interface shared by two implementations,
``simple_httpclient`` and ``curl_httpclient``. Applications may either
instantiate their chosen implementation class directly or use the
`AsyncHTTPClient` class from this module, which selects an implementation
that can be overridden with the `AsyncHTTPClient.configure` method.
The default implementation is ``simple_httpclient``, and this is expected
to be suitable for most users' needs. However, some applications may wish
to switch to ``curl_httpclient`` for reasons such as the following:
* ``curl_httpclient`` has some features not found in ``simple_httpclient``,
including support for HTTP proxies and the ability to use a specified
network interface.
* ``curl_httpclient`` is more likely to be compatible with sites that are
not-quite-compliant with the HTTP spec, or sites that use little-exercised
features of HTTP.
* ``curl_httpclient`` is faster.
Note that if you are using ``curl_httpclient``, it is highly
recommended that you use a recent version of ``libcurl`` and
``pycurl``. Currently the minimum supported version of libcurl is
7.22.0, and the minimum version of pycurl is 7.18.2. It is highly
recommended that your ``libcurl`` installation is built with
asynchronous DNS resolver (threaded or c-ares), otherwise you may
encounter various problems with request timeouts (for more
information, see
http://curl.haxx.se/libcurl/c/curl_easy_setopt.html#CURLOPTCONNECTTIMEOUTMS
and comments in curl_httpclient.py).
To select ``curl_httpclient``, call `AsyncHTTPClient.configure` at startup::
AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient")
"""
import datetime
import functools
from io import BytesIO
import ssl
import time
import weakref
from tornado.concurrent import (
Future,
future_set_result_unless_cancelled,
future_set_exception_unless_cancelled,
)
from tornado.escape import utf8, native_str
from tornado import gen, httputil
from tornado.ioloop import IOLoop
from tornado.util import Configurable
from typing import Type, Any, Union, Dict, Callable, Optional, cast, Awaitable
class HTTPClient(object):
"""A blocking HTTP client.
This interface is provided to make it easier to share code between
synchronous and asynchronous applications. Applications that are
running an `.IOLoop` must use `AsyncHTTPClient` instead.
Typical usage looks like this::
http_client = httpclient.HTTPClient()
try:
response = http_client.fetch("http://www.google.com/")
print(response.body)
except httpclient.HTTPError as e:
# HTTPError is raised for non-200 responses; the response
# can be found in e.response.
print("Error: " + str(e))
except Exception as e:
# Other errors are possible, such as IOError.
print("Error: " + str(e))
http_client.close()
.. versionchanged:: 5.0
Due to limitations in `asyncio`, it is no longer possible to
use the synchronous ``HTTPClient`` while an `.IOLoop` is running.
Use `AsyncHTTPClient` instead.
"""
def __init__(
self, async_client_class: Type["AsyncHTTPClient"] = None, **kwargs: Any
) -> None:
# Initialize self._closed at the beginning of the constructor
# so that an exception raised here doesn't lead to confusing
# failures in __del__.
self._closed = True
self._io_loop = IOLoop(make_current=False)
if async_client_class is None:
async_client_class = AsyncHTTPClient
# Create the client while our IOLoop is "current", without
# clobbering the thread's real current IOLoop (if any).
async def make_client() -> "AsyncHTTPClient":
await gen.sleep(0)
assert async_client_class is not None
return async_client_class(**kwargs)
self._async_client = self._io_loop.run_sync(make_client)
self._closed = False
def __del__(self) -> None:
self.close()
def close(self) -> None:
"""Closes the HTTPClient, freeing any resources used."""
if not self._closed:
self._async_client.close()
self._io_loop.close()
self._closed = True
def fetch(
self, request: Union["HTTPRequest", str], **kwargs: Any
) -> "HTTPResponse":
"""Executes a request, returning an `HTTPResponse`.
The request may be either a string URL or an `HTTPRequest` object.
If it is a string, we construct an `HTTPRequest` using any additional
kwargs: ``HTTPRequest(request, **kwargs)``
If an error occurs during the fetch, we raise an `HTTPError` unless
the ``raise_error`` keyword argument is set to False.
"""
response = self._io_loop.run_sync(
functools.partial(self._async_client.fetch, request, **kwargs)
)
return response
class AsyncHTTPClient(Configurable):
"""An non-blocking HTTP client.
Example usage::
async def f():
http_client = AsyncHTTPClient()
try:
response = await http_client.fetch("http://www.google.com")
except Exception as e:
print("Error: %s" % e)
else:
print(response.body)
The constructor for this class is magic in several respects: It
actually creates an instance of an implementation-specific
subclass, and instances are reused as a kind of pseudo-singleton
(one per `.IOLoop`). The keyword argument ``force_instance=True``
can be used to suppress this singleton behavior. Unless
``force_instance=True`` is used, no arguments should be passed to
the `AsyncHTTPClient` constructor. The implementation subclass as
well as arguments to its constructor can be set with the static
method `configure()`
All `AsyncHTTPClient` implementations support a ``defaults``
keyword argument, which can be used to set default values for
`HTTPRequest` attributes. For example::
AsyncHTTPClient.configure(
None, defaults=dict(user_agent="MyUserAgent"))
# or with force_instance:
client = AsyncHTTPClient(force_instance=True,
defaults=dict(user_agent="MyUserAgent"))
.. versionchanged:: 5.0
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
"""
_instance_cache = None # type: Dict[IOLoop, AsyncHTTPClient]
@classmethod
def configurable_base(cls) -> Type[Configurable]:
return AsyncHTTPClient
@classmethod
def configurable_default(cls) -> Type[Configurable]:
from tornado.simple_httpclient import SimpleAsyncHTTPClient
return SimpleAsyncHTTPClient
@classmethod
def _async_clients(cls) -> Dict[IOLoop, "AsyncHTTPClient"]:
attr_name = "_async_client_dict_" + cls.__name__
if not hasattr(cls, attr_name):
setattr(cls, attr_name, weakref.WeakKeyDictionary())
return getattr(cls, attr_name)
def __new__(cls, force_instance: bool = False, **kwargs: Any) -> "AsyncHTTPClient":
io_loop = IOLoop.current()
if force_instance:
instance_cache = None
else:
instance_cache = cls._async_clients()
if instance_cache is not None and io_loop in instance_cache:
return instance_cache[io_loop]
instance = super(AsyncHTTPClient, cls).__new__(cls, **kwargs) # type: ignore
# Make sure the instance knows which cache to remove itself from.
# It can't simply call _async_clients() because we may be in
# __new__(AsyncHTTPClient) but instance.__class__ may be
# SimpleAsyncHTTPClient.
instance._instance_cache = instance_cache
if instance_cache is not None:
instance_cache[instance.io_loop] = instance
return instance
def initialize(self, defaults: Dict[str, Any] = None) -> None:
self.io_loop = IOLoop.current()
self.defaults = dict(HTTPRequest._DEFAULTS)
if defaults is not None:
self.defaults.update(defaults)
self._closed = False
def close(self) -> None:
"""Destroys this HTTP client, freeing any file descriptors used.
This method is **not needed in normal use** due to the way
that `AsyncHTTPClient` objects are transparently reused.
``close()`` is generally only necessary when either the
`.IOLoop` is also being closed, or the ``force_instance=True``
argument was used when creating the `AsyncHTTPClient`.
No other methods may be called on the `AsyncHTTPClient` after
``close()``.
"""
if self._closed:
return
self._closed = True
if self._instance_cache is not None:
cached_val = self._instance_cache.pop(self.io_loop, None)
# If there's an object other than self in the instance
# cache for our IOLoop, something has gotten mixed up. A
# value of None appears to be possible when this is called
# from a destructor (HTTPClient.__del__) as the weakref
# gets cleared before the destructor runs.
if cached_val is not None and cached_val is not self:
raise RuntimeError("inconsistent AsyncHTTPClient cache")
def fetch(
self,
request: Union[str, "HTTPRequest"],
raise_error: bool = True,
**kwargs: Any
) -> Awaitable["HTTPResponse"]:
"""Executes a request, asynchronously returning an `HTTPResponse`.
The request may be either a string URL or an `HTTPRequest` object.
If it is a string, we construct an `HTTPRequest` using any additional
kwargs: ``HTTPRequest(request, **kwargs)``
This method returns a `.Future` whose result is an
`HTTPResponse`. By default, the ``Future`` will raise an
`HTTPError` if the request returned a non-200 response code
(other errors may also be raised if the server could not be
contacted). Instead, if ``raise_error`` is set to False, the
response will always be returned regardless of the response
code.
If a ``callback`` is given, it will be invoked with the `HTTPResponse`.
In the callback interface, `HTTPError` is not automatically raised.
Instead, you must check the response's ``error`` attribute or
call its `~HTTPResponse.rethrow` method.
.. versionchanged:: 6.0
The ``callback`` argument was removed. Use the returned
`.Future` instead.
The ``raise_error=False`` argument only affects the
`HTTPError` raised when a non-200 response code is used,
instead of suppressing all errors.
"""
if self._closed:
raise RuntimeError("fetch() called on closed AsyncHTTPClient")
if not isinstance(request, HTTPRequest):
request = HTTPRequest(url=request, **kwargs)
else:
if kwargs:
raise ValueError(
"kwargs can't be used if request is an HTTPRequest object"
)
# We may modify this (to add Host, Accept-Encoding, etc),
# so make sure we don't modify the caller's object. This is also
# where normal dicts get converted to HTTPHeaders objects.
request.headers = httputil.HTTPHeaders(request.headers)
request_proxy = _RequestProxy(request, self.defaults)
future = Future() # type: Future[HTTPResponse]
def handle_response(response: "HTTPResponse") -> None:
if response.error:
if raise_error or not response._error_is_response_code:
future_set_exception_unless_cancelled(future, response.error)
return
future_set_result_unless_cancelled(future, response)
self.fetch_impl(cast(HTTPRequest, request_proxy), handle_response)
return future
def fetch_impl(
self, request: "HTTPRequest", callback: Callable[["HTTPResponse"], None]
) -> None:
raise NotImplementedError()
@classmethod
def configure(
cls, impl: "Union[None, str, Type[Configurable]]", **kwargs: Any
) -> None:
"""Configures the `AsyncHTTPClient` subclass to use.
``AsyncHTTPClient()`` actually creates an instance of a subclass.
This method may be called with either a class object or the
fully-qualified name of such a class (or ``None`` to use the default,
``SimpleAsyncHTTPClient``)
If additional keyword arguments are given, they will be passed
to the constructor of each subclass instance created. The
keyword argument ``max_clients`` determines the maximum number
of simultaneous `~AsyncHTTPClient.fetch()` operations that can
execute in parallel on each `.IOLoop`. Additional arguments
may be supported depending on the implementation class in use.
Example::
AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient")
"""
super(AsyncHTTPClient, cls).configure(impl, **kwargs)
class HTTPRequest(object):
"""HTTP client request object."""
_headers = None # type: Union[Dict[str, str], httputil.HTTPHeaders]
# Default values for HTTPRequest parameters.
# Merged with the values on the request object by AsyncHTTPClient
# implementations.
_DEFAULTS = dict(
connect_timeout=20.0,
request_timeout=20.0,
follow_redirects=True,
max_redirects=5,
decompress_response=True,
proxy_password="",
allow_nonstandard_methods=False,
validate_cert=True,
)
def __init__(
self,
url: str,
method: str = "GET",
headers: Union[Dict[str, str], httputil.HTTPHeaders] = None,
body: Union[bytes, str] = None,
auth_username: str = None,
auth_password: str = None,
auth_mode: str = None,
connect_timeout: float = None,
request_timeout: float = None,
if_modified_since: Union[float, datetime.datetime] = None,
follow_redirects: bool = None,
max_redirects: int = None,
user_agent: str = None,
use_gzip: bool = None,
network_interface: str = None,
streaming_callback: Callable[[bytes], None] = None,
header_callback: Callable[[str], None] = None,
prepare_curl_callback: Callable[[Any], None] = None,
proxy_host: str = None,
proxy_port: int = None,
proxy_username: str = None,
proxy_password: str = None,
proxy_auth_mode: str = None,
allow_nonstandard_methods: bool = None,
validate_cert: bool = None,
ca_certs: str = None,
allow_ipv6: bool = None,
client_key: str = None,
client_cert: str = None,
body_producer: Callable[[Callable[[bytes], None]], "Future[None]"] = None,
expect_100_continue: bool = False,
decompress_response: bool = None,
ssl_options: Union[Dict[str, Any], ssl.SSLContext] = None,
) -> None:
r"""All parameters except ``url`` are optional.
:arg str url: URL to fetch
:arg str method: HTTP method, e.g. "GET" or "POST"
:arg headers: Additional HTTP headers to pass on the request
:type headers: `~tornado.httputil.HTTPHeaders` or `dict`
:arg body: HTTP request body as a string (byte or unicode; if unicode
the utf-8 encoding will be used)
:arg body_producer: Callable used for lazy/asynchronous request bodies.
It is called with one argument, a ``write`` function, and should
return a `.Future`. It should call the write function with new
data as it becomes available. The write function returns a
`.Future` which can be used for flow control.
Only one of ``body`` and ``body_producer`` may
be specified. ``body_producer`` is not supported on
``curl_httpclient``. When using ``body_producer`` it is recommended
to pass a ``Content-Length`` in the headers as otherwise chunked
encoding will be used, and many servers do not support chunked
encoding on requests. New in Tornado 4.0
:arg str auth_username: Username for HTTP authentication
:arg str auth_password: Password for HTTP authentication
:arg str auth_mode: Authentication mode; default is "basic".
Allowed values are implementation-defined; ``curl_httpclient``
supports "basic" and "digest"; ``simple_httpclient`` only supports
"basic"
:arg float connect_timeout: Timeout for initial connection in seconds,
default 20 seconds
:arg float request_timeout: Timeout for entire request in seconds,
default 20 seconds
:arg if_modified_since: Timestamp for ``If-Modified-Since`` header
:type if_modified_since: `datetime` or `float`
:arg bool follow_redirects: Should redirects be followed automatically
or return the 3xx response? Default True.
:arg int max_redirects: Limit for ``follow_redirects``, default 5.
:arg str user_agent: String to send as ``User-Agent`` header
:arg bool decompress_response: Request a compressed response from
the server and decompress it after downloading. Default is True.
New in Tornado 4.0.
:arg bool use_gzip: Deprecated alias for ``decompress_response``
since Tornado 4.0.
:arg str network_interface: Network interface or source IP to use for request.
See ``curl_httpclient`` note below.
:arg collections.abc.Callable streaming_callback: If set, ``streaming_callback`` will
be run with each chunk of data as it is received, and
``HTTPResponse.body`` and ``HTTPResponse.buffer`` will be empty in
the final response.
:arg collections.abc.Callable header_callback: If set, ``header_callback`` will
be run with each header line as it is received (including the
first line, e.g. ``HTTP/1.0 200 OK\r\n``, and a final line
containing only ``\r\n``. All lines include the trailing newline
characters). ``HTTPResponse.headers`` will be empty in the final
response. This is most useful in conjunction with
``streaming_callback``, because it's the only way to get access to
header data while the request is in progress.
:arg collections.abc.Callable prepare_curl_callback: If set, will be called with
a ``pycurl.Curl`` object to allow the application to make additional
``setopt`` calls.
:arg str proxy_host: HTTP proxy hostname. To use proxies,
``proxy_host`` and ``proxy_port`` must be set; ``proxy_username``,
``proxy_pass`` and ``proxy_auth_mode`` are optional. Proxies are
currently only supported with ``curl_httpclient``.
:arg int proxy_port: HTTP proxy port
:arg str proxy_username: HTTP proxy username
:arg str proxy_password: HTTP proxy password
:arg str proxy_auth_mode: HTTP proxy Authentication mode;
default is "basic". supports "basic" and "digest"
:arg bool allow_nonstandard_methods: Allow unknown values for ``method``
argument? Default is False.
:arg bool validate_cert: For HTTPS requests, validate the server's
certificate? Default is True.
:arg str ca_certs: filename of CA certificates in PEM format,
or None to use defaults. See note below when used with
``curl_httpclient``.
:arg str client_key: Filename for client SSL key, if any. See
note below when used with ``curl_httpclient``.
:arg str client_cert: Filename for client SSL certificate, if any.
See note below when used with ``curl_httpclient``.
:arg ssl.SSLContext ssl_options: `ssl.SSLContext` object for use in
``simple_httpclient`` (unsupported by ``curl_httpclient``).
Overrides ``validate_cert``, ``ca_certs``, ``client_key``,
and ``client_cert``.
:arg bool allow_ipv6: Use IPv6 when available? Default is True.
:arg bool expect_100_continue: If true, send the
``Expect: 100-continue`` header and wait for a continue response
before sending the request body. Only supported with
``simple_httpclient``.
.. note::
When using ``curl_httpclient`` certain options may be
inherited by subsequent fetches because ``pycurl`` does
not allow them to be cleanly reset. This applies to the
``ca_certs``, ``client_key``, ``client_cert``, and
``network_interface`` arguments. If you use these
options, you should pass them on every request (you don't
have to always use the same values, but it's not possible
to mix requests that specify these options with ones that
use the defaults).
.. versionadded:: 3.1
The ``auth_mode`` argument.
.. versionadded:: 4.0
The ``body_producer`` and ``expect_100_continue`` arguments.
.. versionadded:: 4.2
The ``ssl_options`` argument.
.. versionadded:: 4.5
The ``proxy_auth_mode`` argument.
"""
# Note that some of these attributes go through property setters
# defined below.
self.headers = headers
if if_modified_since:
self.headers["If-Modified-Since"] = httputil.format_timestamp(
if_modified_since
)
self.proxy_host = proxy_host
self.proxy_port = proxy_port
self.proxy_username = proxy_username
self.proxy_password = proxy_password
self.proxy_auth_mode = proxy_auth_mode
self.url = url
self.method = method
self.body = body
self.body_producer = body_producer
self.auth_username = auth_username
self.auth_password = auth_password
self.auth_mode = auth_mode
self.connect_timeout = connect_timeout
self.request_timeout = request_timeout
self.follow_redirects = follow_redirects
self.max_redirects = max_redirects
self.user_agent = user_agent
if decompress_response is not None:
self.decompress_response = decompress_response # type: Optional[bool]
else:
self.decompress_response = use_gzip
self.network_interface = network_interface
self.streaming_callback = streaming_callback
self.header_callback = header_callback
self.prepare_curl_callback = prepare_curl_callback
self.allow_nonstandard_methods = allow_nonstandard_methods
self.validate_cert = validate_cert
self.ca_certs = ca_certs
self.allow_ipv6 = allow_ipv6
self.client_key = client_key
self.client_cert = client_cert
self.ssl_options = ssl_options
self.expect_100_continue = expect_100_continue
self.start_time = time.time()
@property
def headers(self) -> httputil.HTTPHeaders:
# TODO: headers may actually be a plain dict until fairly late in
# the process (AsyncHTTPClient.fetch), but practically speaking,
# whenever the property is used they're already HTTPHeaders.
return self._headers # type: ignore
@headers.setter
def headers(self, value: Union[Dict[str, str], httputil.HTTPHeaders]) -> None:
if value is None:
self._headers = httputil.HTTPHeaders()
else:
self._headers = value # type: ignore
@property
def body(self) -> bytes:
return self._body
@body.setter
def body(self, value: Union[bytes, str]) -> None:
self._body = utf8(value)
class HTTPResponse(object):
"""HTTP Response object.
Attributes:
* ``request``: HTTPRequest object
* ``code``: numeric HTTP status code, e.g. 200 or 404
* ``reason``: human-readable reason phrase describing the status code
* ``headers``: `tornado.httputil.HTTPHeaders` object
* ``effective_url``: final location of the resource after following any
redirects
* ``buffer``: ``cStringIO`` object for response body
* ``body``: response body as bytes (created on demand from ``self.buffer``)
* ``error``: Exception object, if any
* ``request_time``: seconds from request start to finish. Includes all
network operations from DNS resolution to receiving the last byte of
data. Does not include time spent in the queue (due to the
``max_clients`` option). If redirects were followed, only includes
the final request.
* ``start_time``: Time at which the HTTP operation started, based on
`time.time` (not the monotonic clock used by `.IOLoop.time`). May
be ``None`` if the request timed out while in the queue.
* ``time_info``: dictionary of diagnostic timing information from the
request. Available data are subject to change, but currently uses timings
available from http://curl.haxx.se/libcurl/c/curl_easy_getinfo.html,
plus ``queue``, which is the delay (if any) introduced by waiting for
a slot under `AsyncHTTPClient`'s ``max_clients`` setting.
.. versionadded:: 5.1
Added the ``start_time`` attribute.
.. versionchanged:: 5.1
The ``request_time`` attribute previously included time spent in the queue
for ``simple_httpclient``, but not in ``curl_httpclient``. Now queueing time
is excluded in both implementations. ``request_time`` is now more accurate for
``curl_httpclient`` because it uses a monotonic clock when available.
"""
# I'm not sure why these don't get type-inferred from the references in __init__.
error = None # type: Optional[BaseException]
_error_is_response_code = False
request = None # type: HTTPRequest
def __init__(
self,
request: HTTPRequest,
code: int,
headers: httputil.HTTPHeaders = None,
buffer: BytesIO = None,
effective_url: str = None,
error: BaseException = None,
request_time: float = None,
time_info: Dict[str, float] = None,
reason: str = None,
start_time: float = None,
) -> None:
if isinstance(request, _RequestProxy):
self.request = request.request
else:
self.request = request
self.code = code
self.reason = reason or httputil.responses.get(code, "Unknown")
if headers is not None:
self.headers = headers
else:
self.headers = httputil.HTTPHeaders()
self.buffer = buffer
self._body = None # type: Optional[bytes]
if effective_url is None:
self.effective_url = request.url
else:
self.effective_url = effective_url
self._error_is_response_code = False
if error is None:
if self.code < 200 or self.code >= 300:
self._error_is_response_code = True
self.error = HTTPError(self.code, message=self.reason, response=self)
else:
self.error = None
else:
self.error = error
self.start_time = start_time
self.request_time = request_time
self.time_info = time_info or {}
@property
def body(self) -> bytes:
if self.buffer is None:
return b""
elif self._body is None:
self._body = self.buffer.getvalue()
return self._body
def rethrow(self) -> None:
"""If there was an error on the request, raise an `HTTPError`."""
if self.error:
raise self.error
def __repr__(self) -> str:
args = ",".join("%s=%r" % i for i in sorted(self.__dict__.items()))
return "%s(%s)" % (self.__class__.__name__, args)
class HTTPClientError(Exception):
"""Exception thrown for an unsuccessful HTTP request.
Attributes:
* ``code`` - HTTP error integer error code, e.g. 404. Error code 599 is
used when no HTTP response was received, e.g. for a timeout.
* ``response`` - `HTTPResponse` object, if any.
Note that if ``follow_redirects`` is False, redirects become HTTPErrors,
and you can look at ``error.response.headers['Location']`` to see the
destination of the redirect.
.. versionchanged:: 5.1
Renamed from ``HTTPError`` to ``HTTPClientError`` to avoid collisions with
`tornado.web.HTTPError`. The name ``tornado.httpclient.HTTPError`` remains
as an alias.
"""
def __init__(
self, code: int, message: str = None, response: HTTPResponse = None
) -> None:
self.code = code
self.message = message or httputil.responses.get(code, "Unknown")
self.response = response
super(HTTPClientError, self).__init__(code, message, response)
def __str__(self) -> str:
return "HTTP %d: %s" % (self.code, self.message)
# There is a cyclic reference between self and self.response,
# which breaks the default __repr__ implementation.
# (especially on pypy, which doesn't have the same recursion
# detection as cpython).
__repr__ = __str__
HTTPError = HTTPClientError
class _RequestProxy(object):
"""Combines an object with a dictionary of defaults.
Used internally by AsyncHTTPClient implementations.
"""
def __init__(
self, request: HTTPRequest, defaults: Optional[Dict[str, Any]]
) -> None:
self.request = request
self.defaults = defaults
def __getattr__(self, name: str) -> Any:
request_attr = getattr(self.request, name)
if request_attr is not None:
return request_attr
elif self.defaults is not None:
return self.defaults.get(name, None)
else:
return None
def main() -> None:
from tornado.options import define, options, parse_command_line
define("print_headers", type=bool, default=False)
define("print_body", type=bool, default=True)
define("follow_redirects", type=bool, default=True)
define("validate_cert", type=bool, default=True)
define("proxy_host", type=str)
define("proxy_port", type=int)
args = parse_command_line()
client = HTTPClient()
for arg in args:
try:
response = client.fetch(
arg,
follow_redirects=options.follow_redirects,
validate_cert=options.validate_cert,
proxy_host=options.proxy_host,
proxy_port=options.proxy_port,
)
except HTTPError as e:
if e.response is not None:
response = e.response
else:
raise
if options.print_headers:
print(response.headers)
if options.print_body:
print(native_str(response.body))
client.close()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,398 @@
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""A non-blocking, single-threaded HTTP server.
Typical applications have little direct interaction with the `HTTPServer`
class except to start a server at the beginning of the process
(and even that is often done indirectly via `tornado.web.Application.listen`).
.. versionchanged:: 4.0
The ``HTTPRequest`` class that used to live in this module has been moved
to `tornado.httputil.HTTPServerRequest`. The old name remains as an alias.
"""
import socket
import ssl
from tornado.escape import native_str
from tornado.http1connection import HTTP1ServerConnection, HTTP1ConnectionParameters
from tornado import httputil
from tornado import iostream
from tornado import netutil
from tornado.tcpserver import TCPServer
from tornado.util import Configurable
import typing
from typing import Union, Any, Dict, Callable, List, Type, Tuple, Optional, Awaitable
if typing.TYPE_CHECKING:
from typing import Set # noqa: F401
class HTTPServer(TCPServer, Configurable, httputil.HTTPServerConnectionDelegate):
r"""A non-blocking, single-threaded HTTP server.
A server is defined by a subclass of `.HTTPServerConnectionDelegate`,
or, for backwards compatibility, a callback that takes an
`.HTTPServerRequest` as an argument. The delegate is usually a
`tornado.web.Application`.
`HTTPServer` supports keep-alive connections by default
(automatically for HTTP/1.1, or for HTTP/1.0 when the client
requests ``Connection: keep-alive``).
If ``xheaders`` is ``True``, we support the
``X-Real-Ip``/``X-Forwarded-For`` and
``X-Scheme``/``X-Forwarded-Proto`` headers, which override the
remote IP and URI scheme/protocol for all requests. These headers
are useful when running Tornado behind a reverse proxy or load
balancer. The ``protocol`` argument can also be set to ``https``
if Tornado is run behind an SSL-decoding proxy that does not set one of
the supported ``xheaders``.
By default, when parsing the ``X-Forwarded-For`` header, Tornado will
select the last (i.e., the closest) address on the list of hosts as the
remote host IP address. To select the next server in the chain, a list of
trusted downstream hosts may be passed as the ``trusted_downstream``
argument. These hosts will be skipped when parsing the ``X-Forwarded-For``
header.
To make this server serve SSL traffic, send the ``ssl_options`` keyword
argument with an `ssl.SSLContext` object. For compatibility with older
versions of Python ``ssl_options`` may also be a dictionary of keyword
arguments for the `ssl.wrap_socket` method.::
ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_ctx.load_cert_chain(os.path.join(data_dir, "mydomain.crt"),
os.path.join(data_dir, "mydomain.key"))
HTTPServer(application, ssl_options=ssl_ctx)
`HTTPServer` initialization follows one of three patterns (the
initialization methods are defined on `tornado.tcpserver.TCPServer`):
1. `~tornado.tcpserver.TCPServer.listen`: simple single-process::
server = HTTPServer(app)
server.listen(8888)
IOLoop.current().start()
In many cases, `tornado.web.Application.listen` can be used to avoid
the need to explicitly create the `HTTPServer`.
2. `~tornado.tcpserver.TCPServer.bind`/`~tornado.tcpserver.TCPServer.start`:
simple multi-process::
server = HTTPServer(app)
server.bind(8888)
server.start(0) # Forks multiple sub-processes
IOLoop.current().start()
When using this interface, an `.IOLoop` must *not* be passed
to the `HTTPServer` constructor. `~.TCPServer.start` will always start
the server on the default singleton `.IOLoop`.
3. `~tornado.tcpserver.TCPServer.add_sockets`: advanced multi-process::
sockets = tornado.netutil.bind_sockets(8888)
tornado.process.fork_processes(0)
server = HTTPServer(app)
server.add_sockets(sockets)
IOLoop.current().start()
The `~.TCPServer.add_sockets` interface is more complicated,
but it can be used with `tornado.process.fork_processes` to
give you more flexibility in when the fork happens.
`~.TCPServer.add_sockets` can also be used in single-process
servers if you want to create your listening sockets in some
way other than `tornado.netutil.bind_sockets`.
.. versionchanged:: 4.0
Added ``decompress_request``, ``chunk_size``, ``max_header_size``,
``idle_connection_timeout``, ``body_timeout``, ``max_body_size``
arguments. Added support for `.HTTPServerConnectionDelegate`
instances as ``request_callback``.
.. versionchanged:: 4.1
`.HTTPServerConnectionDelegate.start_request` is now called with
two arguments ``(server_conn, request_conn)`` (in accordance with the
documentation) instead of one ``(request_conn)``.
.. versionchanged:: 4.2
`HTTPServer` is now a subclass of `tornado.util.Configurable`.
.. versionchanged:: 4.5
Added the ``trusted_downstream`` argument.
.. versionchanged:: 5.0
The ``io_loop`` argument has been removed.
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
# Ignore args to __init__; real initialization belongs in
# initialize since we're Configurable. (there's something
# weird in initialization order between this class,
# Configurable, and TCPServer so we can't leave __init__ out
# completely)
pass
def initialize(
self,
request_callback: Union[
httputil.HTTPServerConnectionDelegate,
Callable[[httputil.HTTPServerRequest], None],
],
no_keep_alive: bool = False,
xheaders: bool = False,
ssl_options: Union[Dict[str, Any], ssl.SSLContext] = None,
protocol: str = None,
decompress_request: bool = False,
chunk_size: int = None,
max_header_size: int = None,
idle_connection_timeout: float = None,
body_timeout: float = None,
max_body_size: int = None,
max_buffer_size: int = None,
trusted_downstream: List[str] = None,
) -> None:
# This method's signature is not extracted with autodoc
# because we want its arguments to appear on the class
# constructor. When changing this signature, also update the
# copy in httpserver.rst.
self.request_callback = request_callback
self.xheaders = xheaders
self.protocol = protocol
self.conn_params = HTTP1ConnectionParameters(
decompress=decompress_request,
chunk_size=chunk_size,
max_header_size=max_header_size,
header_timeout=idle_connection_timeout or 3600,
max_body_size=max_body_size,
body_timeout=body_timeout,
no_keep_alive=no_keep_alive,
)
TCPServer.__init__(
self,
ssl_options=ssl_options,
max_buffer_size=max_buffer_size,
read_chunk_size=chunk_size,
)
self._connections = set() # type: Set[HTTP1ServerConnection]
self.trusted_downstream = trusted_downstream
@classmethod
def configurable_base(cls) -> Type[Configurable]:
return HTTPServer
@classmethod
def configurable_default(cls) -> Type[Configurable]:
return HTTPServer
async def close_all_connections(self) -> None:
"""Close all open connections and asynchronously wait for them to finish.
This method is used in combination with `~.TCPServer.stop` to
support clean shutdowns (especially for unittests). Typical
usage would call ``stop()`` first to stop accepting new
connections, then ``await close_all_connections()`` to wait for
existing connections to finish.
This method does not currently close open websocket connections.
Note that this method is a coroutine and must be caled with ``await``.
"""
while self._connections:
# Peek at an arbitrary element of the set
conn = next(iter(self._connections))
await conn.close()
def handle_stream(self, stream: iostream.IOStream, address: Tuple) -> None:
context = _HTTPRequestContext(
stream, address, self.protocol, self.trusted_downstream
)
conn = HTTP1ServerConnection(stream, self.conn_params, context)
self._connections.add(conn)
conn.start_serving(self)
def start_request(
self, server_conn: object, request_conn: httputil.HTTPConnection
) -> httputil.HTTPMessageDelegate:
if isinstance(self.request_callback, httputil.HTTPServerConnectionDelegate):
delegate = self.request_callback.start_request(server_conn, request_conn)
else:
delegate = _CallableAdapter(self.request_callback, request_conn)
if self.xheaders:
delegate = _ProxyAdapter(delegate, request_conn)
return delegate
def on_close(self, server_conn: object) -> None:
self._connections.remove(typing.cast(HTTP1ServerConnection, server_conn))
class _CallableAdapter(httputil.HTTPMessageDelegate):
def __init__(
self,
request_callback: Callable[[httputil.HTTPServerRequest], None],
request_conn: httputil.HTTPConnection,
) -> None:
self.connection = request_conn
self.request_callback = request_callback
self.request = None # type: Optional[httputil.HTTPServerRequest]
self.delegate = None
self._chunks = [] # type: List[bytes]
def headers_received(
self,
start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine],
headers: httputil.HTTPHeaders,
) -> Optional[Awaitable[None]]:
self.request = httputil.HTTPServerRequest(
connection=self.connection,
start_line=typing.cast(httputil.RequestStartLine, start_line),
headers=headers,
)
return None
def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]:
self._chunks.append(chunk)
return None
def finish(self) -> None:
assert self.request is not None
self.request.body = b"".join(self._chunks)
self.request._parse_body()
self.request_callback(self.request)
def on_connection_close(self) -> None:
del self._chunks
class _HTTPRequestContext(object):
def __init__(
self,
stream: iostream.IOStream,
address: Tuple,
protocol: Optional[str],
trusted_downstream: List[str] = None,
) -> None:
self.address = address
# Save the socket's address family now so we know how to
# interpret self.address even after the stream is closed
# and its socket attribute replaced with None.
if stream.socket is not None:
self.address_family = stream.socket.family
else:
self.address_family = None
# In HTTPServerRequest we want an IP, not a full socket address.
if (
self.address_family in (socket.AF_INET, socket.AF_INET6)
and address is not None
):
self.remote_ip = address[0]
else:
# Unix (or other) socket; fake the remote address.
self.remote_ip = "0.0.0.0"
if protocol:
self.protocol = protocol
elif isinstance(stream, iostream.SSLIOStream):
self.protocol = "https"
else:
self.protocol = "http"
self._orig_remote_ip = self.remote_ip
self._orig_protocol = self.protocol
self.trusted_downstream = set(trusted_downstream or [])
def __str__(self) -> str:
if self.address_family in (socket.AF_INET, socket.AF_INET6):
return self.remote_ip
elif isinstance(self.address, bytes):
# Python 3 with the -bb option warns about str(bytes),
# so convert it explicitly.
# Unix socket addresses are str on mac but bytes on linux.
return native_str(self.address)
else:
return str(self.address)
def _apply_xheaders(self, headers: httputil.HTTPHeaders) -> None:
"""Rewrite the ``remote_ip`` and ``protocol`` fields."""
# Squid uses X-Forwarded-For, others use X-Real-Ip
ip = headers.get("X-Forwarded-For", self.remote_ip)
# Skip trusted downstream hosts in X-Forwarded-For list
for ip in (cand.strip() for cand in reversed(ip.split(","))):
if ip not in self.trusted_downstream:
break
ip = headers.get("X-Real-Ip", ip)
if netutil.is_valid_ip(ip):
self.remote_ip = ip
# AWS uses X-Forwarded-Proto
proto_header = headers.get(
"X-Scheme", headers.get("X-Forwarded-Proto", self.protocol)
)
if proto_header:
# use only the last proto entry if there is more than one
# TODO: support trusting mutiple layers of proxied protocol
proto_header = proto_header.split(",")[-1].strip()
if proto_header in ("http", "https"):
self.protocol = proto_header
def _unapply_xheaders(self) -> None:
"""Undo changes from `_apply_xheaders`.
Xheaders are per-request so they should not leak to the next
request on the same connection.
"""
self.remote_ip = self._orig_remote_ip
self.protocol = self._orig_protocol
class _ProxyAdapter(httputil.HTTPMessageDelegate):
def __init__(
self,
delegate: httputil.HTTPMessageDelegate,
request_conn: httputil.HTTPConnection,
) -> None:
self.connection = request_conn
self.delegate = delegate
def headers_received(
self,
start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine],
headers: httputil.HTTPHeaders,
) -> Optional[Awaitable[None]]:
# TODO: either make context an official part of the
# HTTPConnection interface or figure out some other way to do this.
self.connection.context._apply_xheaders(headers) # type: ignore
return self.delegate.headers_received(start_line, headers)
def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]:
return self.delegate.data_received(chunk)
def finish(self) -> None:
self.delegate.finish()
self._cleanup()
def on_connection_close(self) -> None:
self.delegate.on_connection_close()
self._cleanup()
def _cleanup(self) -> None:
self.connection.context._unapply_xheaders() # type: ignore
HTTPRequest = httputil.HTTPServerRequest

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,946 @@
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""An I/O event loop for non-blocking sockets.
In Tornado 6.0, `.IOLoop` is a wrapper around the `asyncio` event
loop, with a slightly different interface for historical reasons.
Applications can use either the `.IOLoop` interface or the underlying
`asyncio` event loop directly (unless compatibility with older
versions of Tornado is desired, in which case `.IOLoop` must be used).
Typical applications will use a single `IOLoop` object, accessed via
`IOLoop.current` class method. The `IOLoop.start` method (or
equivalently, `asyncio.AbstractEventLoop.run_forever`) should usually
be called at the end of the ``main()`` function. Atypical applications
may use more than one `IOLoop`, such as one `IOLoop` per thread, or
per `unittest` case.
"""
import asyncio
import concurrent.futures
import datetime
import functools
import logging
import numbers
import os
import sys
import time
import math
import random
from tornado.concurrent import (
Future,
is_future,
chain_future,
future_set_exc_info,
future_add_done_callback,
)
from tornado.log import app_log
from tornado.util import Configurable, TimeoutError, import_object
import typing
from typing import Union, Any, Type, Optional, Callable, TypeVar, Tuple, Awaitable
if typing.TYPE_CHECKING:
from typing import Dict, List # noqa: F401
from typing_extensions import Protocol
else:
Protocol = object
class _Selectable(Protocol):
def fileno(self) -> int:
pass
def close(self) -> None:
pass
_T = TypeVar("_T")
_S = TypeVar("_S", bound=_Selectable)
class IOLoop(Configurable):
"""An I/O event loop.
As of Tornado 6.0, `IOLoop` is a wrapper around the `asyncio` event
loop.
Example usage for a simple TCP server:
.. testcode::
import errno
import functools
import socket
import tornado.ioloop
from tornado.iostream import IOStream
async def handle_connection(connection, address):
stream = IOStream(connection)
message = await stream.read_until_close()
print("message from client:", message.decode().strip())
def connection_ready(sock, fd, events):
while True:
try:
connection, address = sock.accept()
except socket.error as e:
if e.args[0] not in (errno.EWOULDBLOCK, errno.EAGAIN):
raise
return
connection.setblocking(0)
io_loop = tornado.ioloop.IOLoop.current()
io_loop.spawn_callback(handle_connection, connection, address)
if __name__ == '__main__':
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.setblocking(0)
sock.bind(("", 8888))
sock.listen(128)
io_loop = tornado.ioloop.IOLoop.current()
callback = functools.partial(connection_ready, sock)
io_loop.add_handler(sock.fileno(), callback, io_loop.READ)
io_loop.start()
.. testoutput::
:hide:
By default, a newly-constructed `IOLoop` becomes the thread's current
`IOLoop`, unless there already is a current `IOLoop`. This behavior
can be controlled with the ``make_current`` argument to the `IOLoop`
constructor: if ``make_current=True``, the new `IOLoop` will always
try to become current and it raises an error if there is already a
current instance. If ``make_current=False``, the new `IOLoop` will
not try to become current.
In general, an `IOLoop` cannot survive a fork or be shared across
processes in any way. When multiple processes are being used, each
process should create its own `IOLoop`, which also implies that
any objects which depend on the `IOLoop` (such as
`.AsyncHTTPClient`) must also be created in the child processes.
As a guideline, anything that starts processes (including the
`tornado.process` and `multiprocessing` modules) should do so as
early as possible, ideally the first thing the application does
after loading its configuration in ``main()``.
.. versionchanged:: 4.2
Added the ``make_current`` keyword argument to the `IOLoop`
constructor.
.. versionchanged:: 5.0
Uses the `asyncio` event loop by default. The
``IOLoop.configure`` method cannot be used on Python 3 except
to redundantly specify the `asyncio` event loop.
"""
# These constants were originally based on constants from the epoll module.
NONE = 0
READ = 0x001
WRITE = 0x004
ERROR = 0x018
# In Python 3, _ioloop_for_asyncio maps from asyncio loops to IOLoops.
_ioloop_for_asyncio = dict() # type: Dict[asyncio.AbstractEventLoop, IOLoop]
@classmethod
def configure(
cls, impl: "Union[None, str, Type[Configurable]]", **kwargs: Any
) -> None:
if asyncio is not None:
from tornado.platform.asyncio import BaseAsyncIOLoop
if isinstance(impl, str):
impl = import_object(impl)
if isinstance(impl, type) and not issubclass(impl, BaseAsyncIOLoop):
raise RuntimeError(
"only AsyncIOLoop is allowed when asyncio is available"
)
super(IOLoop, cls).configure(impl, **kwargs)
@staticmethod
def instance() -> "IOLoop":
"""Deprecated alias for `IOLoop.current()`.
.. versionchanged:: 5.0
Previously, this method returned a global singleton
`IOLoop`, in contrast with the per-thread `IOLoop` returned
by `current()`. In nearly all cases the two were the same
(when they differed, it was generally used from non-Tornado
threads to communicate back to the main thread's `IOLoop`).
This distinction is not present in `asyncio`, so in order
to facilitate integration with that package `instance()`
was changed to be an alias to `current()`. Applications
using the cross-thread communications aspect of
`instance()` should instead set their own global variable
to point to the `IOLoop` they want to use.
.. deprecated:: 5.0
"""
return IOLoop.current()
def install(self) -> None:
"""Deprecated alias for `make_current()`.
.. versionchanged:: 5.0
Previously, this method would set this `IOLoop` as the
global singleton used by `IOLoop.instance()`. Now that
`instance()` is an alias for `current()`, `install()`
is an alias for `make_current()`.
.. deprecated:: 5.0
"""
self.make_current()
@staticmethod
def clear_instance() -> None:
"""Deprecated alias for `clear_current()`.
.. versionchanged:: 5.0
Previously, this method would clear the `IOLoop` used as
the global singleton by `IOLoop.instance()`. Now that
`instance()` is an alias for `current()`,
`clear_instance()` is an alias for `clear_current()`.
.. deprecated:: 5.0
"""
IOLoop.clear_current()
@typing.overload
@staticmethod
def current() -> "IOLoop":
pass
@typing.overload # noqa: F811
@staticmethod
def current(instance: bool = True) -> Optional["IOLoop"]:
pass
@staticmethod # noqa: F811
def current(instance: bool = True) -> Optional["IOLoop"]:
"""Returns the current thread's `IOLoop`.
If an `IOLoop` is currently running or has been marked as
current by `make_current`, returns that instance. If there is
no current `IOLoop` and ``instance`` is true, creates one.
.. versionchanged:: 4.1
Added ``instance`` argument to control the fallback to
`IOLoop.instance()`.
.. versionchanged:: 5.0
On Python 3, control of the current `IOLoop` is delegated
to `asyncio`, with this and other methods as pass-through accessors.
The ``instance`` argument now controls whether an `IOLoop`
is created automatically when there is none, instead of
whether we fall back to `IOLoop.instance()` (which is now
an alias for this method). ``instance=False`` is deprecated,
since even if we do not create an `IOLoop`, this method
may initialize the asyncio loop.
"""
try:
loop = asyncio.get_event_loop()
except (RuntimeError, AssertionError):
if not instance:
return None
raise
try:
return IOLoop._ioloop_for_asyncio[loop]
except KeyError:
if instance:
from tornado.platform.asyncio import AsyncIOMainLoop
current = AsyncIOMainLoop(make_current=True) # type: Optional[IOLoop]
else:
current = None
return current
def make_current(self) -> None:
"""Makes this the `IOLoop` for the current thread.
An `IOLoop` automatically becomes current for its thread
when it is started, but it is sometimes useful to call
`make_current` explicitly before starting the `IOLoop`,
so that code run at startup time can find the right
instance.
.. versionchanged:: 4.1
An `IOLoop` created while there is no current `IOLoop`
will automatically become current.
.. versionchanged:: 5.0
This method also sets the current `asyncio` event loop.
"""
# The asyncio event loops override this method.
raise NotImplementedError()
@staticmethod
def clear_current() -> None:
"""Clears the `IOLoop` for the current thread.
Intended primarily for use by test frameworks in between tests.
.. versionchanged:: 5.0
This method also clears the current `asyncio` event loop.
"""
old = IOLoop.current(instance=False)
if old is not None:
old._clear_current_hook()
if asyncio is None:
IOLoop._current.instance = None
def _clear_current_hook(self) -> None:
"""Instance method called when an IOLoop ceases to be current.
May be overridden by subclasses as a counterpart to make_current.
"""
pass
@classmethod
def configurable_base(cls) -> Type[Configurable]:
return IOLoop
@classmethod
def configurable_default(cls) -> Type[Configurable]:
from tornado.platform.asyncio import AsyncIOLoop
return AsyncIOLoop
def initialize(self, make_current: bool = None) -> None:
if make_current is None:
if IOLoop.current(instance=False) is None:
self.make_current()
elif make_current:
current = IOLoop.current(instance=False)
# AsyncIO loops can already be current by this point.
if current is not None and current is not self:
raise RuntimeError("current IOLoop already exists")
self.make_current()
def close(self, all_fds: bool = False) -> None:
"""Closes the `IOLoop`, freeing any resources used.
If ``all_fds`` is true, all file descriptors registered on the
IOLoop will be closed (not just the ones created by the
`IOLoop` itself).
Many applications will only use a single `IOLoop` that runs for the
entire lifetime of the process. In that case closing the `IOLoop`
is not necessary since everything will be cleaned up when the
process exits. `IOLoop.close` is provided mainly for scenarios
such as unit tests, which create and destroy a large number of
``IOLoops``.
An `IOLoop` must be completely stopped before it can be closed. This
means that `IOLoop.stop()` must be called *and* `IOLoop.start()` must
be allowed to return before attempting to call `IOLoop.close()`.
Therefore the call to `close` will usually appear just after
the call to `start` rather than near the call to `stop`.
.. versionchanged:: 3.1
If the `IOLoop` implementation supports non-integer objects
for "file descriptors", those objects will have their
``close`` method when ``all_fds`` is true.
"""
raise NotImplementedError()
@typing.overload
def add_handler(
self, fd: int, handler: Callable[[int, int], None], events: int
) -> None:
pass
@typing.overload # noqa: F811
def add_handler(
self, fd: _S, handler: Callable[[_S, int], None], events: int
) -> None:
pass
def add_handler( # noqa: F811
self, fd: Union[int, _Selectable], handler: Callable[..., None], events: int
) -> None:
"""Registers the given handler to receive the given events for ``fd``.
The ``fd`` argument may either be an integer file descriptor or
a file-like object with a ``fileno()`` and ``close()`` method.
The ``events`` argument is a bitwise or of the constants
``IOLoop.READ``, ``IOLoop.WRITE``, and ``IOLoop.ERROR``.
When an event occurs, ``handler(fd, events)`` will be run.
.. versionchanged:: 4.0
Added the ability to pass file-like objects in addition to
raw file descriptors.
"""
raise NotImplementedError()
def update_handler(self, fd: Union[int, _Selectable], events: int) -> None:
"""Changes the events we listen for ``fd``.
.. versionchanged:: 4.0
Added the ability to pass file-like objects in addition to
raw file descriptors.
"""
raise NotImplementedError()
def remove_handler(self, fd: Union[int, _Selectable]) -> None:
"""Stop listening for events on ``fd``.
.. versionchanged:: 4.0
Added the ability to pass file-like objects in addition to
raw file descriptors.
"""
raise NotImplementedError()
def start(self) -> None:
"""Starts the I/O loop.
The loop will run until one of the callbacks calls `stop()`, which
will make the loop stop after the current event iteration completes.
"""
raise NotImplementedError()
def _setup_logging(self) -> None:
"""The IOLoop catches and logs exceptions, so it's
important that log output be visible. However, python's
default behavior for non-root loggers (prior to python
3.2) is to print an unhelpful "no handlers could be
found" message rather than the actual log entry, so we
must explicitly configure logging if we've made it this
far without anything.
This method should be called from start() in subclasses.
"""
if not any(
[
logging.getLogger().handlers,
logging.getLogger("tornado").handlers,
logging.getLogger("tornado.application").handlers,
]
):
logging.basicConfig()
def stop(self) -> None:
"""Stop the I/O loop.
If the event loop is not currently running, the next call to `start()`
will return immediately.
Note that even after `stop` has been called, the `IOLoop` is not
completely stopped until `IOLoop.start` has also returned.
Some work that was scheduled before the call to `stop` may still
be run before the `IOLoop` shuts down.
"""
raise NotImplementedError()
def run_sync(self, func: Callable, timeout: float = None) -> Any:
"""Starts the `IOLoop`, runs the given function, and stops the loop.
The function must return either an awaitable object or
``None``. If the function returns an awaitable object, the
`IOLoop` will run until the awaitable is resolved (and
`run_sync()` will return the awaitable's result). If it raises
an exception, the `IOLoop` will stop and the exception will be
re-raised to the caller.
The keyword-only argument ``timeout`` may be used to set
a maximum duration for the function. If the timeout expires,
a `tornado.util.TimeoutError` is raised.
This method is useful to allow asynchronous calls in a
``main()`` function::
async def main():
# do stuff...
if __name__ == '__main__':
IOLoop.current().run_sync(main)
.. versionchanged:: 4.3
Returning a non-``None``, non-awaitable value is now an error.
.. versionchanged:: 5.0
If a timeout occurs, the ``func`` coroutine will be cancelled.
"""
future_cell = [None] # type: List[Optional[Future]]
def run() -> None:
try:
result = func()
if result is not None:
from tornado.gen import convert_yielded
result = convert_yielded(result)
except Exception:
fut = Future() # type: Future[Any]
future_cell[0] = fut
future_set_exc_info(fut, sys.exc_info())
else:
if is_future(result):
future_cell[0] = result
else:
fut = Future()
future_cell[0] = fut
fut.set_result(result)
assert future_cell[0] is not None
self.add_future(future_cell[0], lambda future: self.stop())
self.add_callback(run)
if timeout is not None:
def timeout_callback() -> None:
# If we can cancel the future, do so and wait on it. If not,
# Just stop the loop and return with the task still pending.
# (If we neither cancel nor wait for the task, a warning
# will be logged).
assert future_cell[0] is not None
if not future_cell[0].cancel():
self.stop()
timeout_handle = self.add_timeout(self.time() + timeout, timeout_callback)
self.start()
if timeout is not None:
self.remove_timeout(timeout_handle)
assert future_cell[0] is not None
if future_cell[0].cancelled() or not future_cell[0].done():
raise TimeoutError("Operation timed out after %s seconds" % timeout)
return future_cell[0].result()
def time(self) -> float:
"""Returns the current time according to the `IOLoop`'s clock.
The return value is a floating-point number relative to an
unspecified time in the past.
Historically, the IOLoop could be customized to use e.g.
`time.monotonic` instead of `time.time`, but this is not
currently supported and so this method is equivalent to
`time.time`.
"""
return time.time()
def add_timeout(
self,
deadline: Union[float, datetime.timedelta],
callback: Callable[..., None],
*args: Any,
**kwargs: Any
) -> object:
"""Runs the ``callback`` at the time ``deadline`` from the I/O loop.
Returns an opaque handle that may be passed to
`remove_timeout` to cancel.
``deadline`` may be a number denoting a time (on the same
scale as `IOLoop.time`, normally `time.time`), or a
`datetime.timedelta` object for a deadline relative to the
current time. Since Tornado 4.0, `call_later` is a more
convenient alternative for the relative case since it does not
require a timedelta object.
Note that it is not safe to call `add_timeout` from other threads.
Instead, you must use `add_callback` to transfer control to the
`IOLoop`'s thread, and then call `add_timeout` from there.
Subclasses of IOLoop must implement either `add_timeout` or
`call_at`; the default implementations of each will call
the other. `call_at` is usually easier to implement, but
subclasses that wish to maintain compatibility with Tornado
versions prior to 4.0 must use `add_timeout` instead.
.. versionchanged:: 4.0
Now passes through ``*args`` and ``**kwargs`` to the callback.
"""
if isinstance(deadline, numbers.Real):
return self.call_at(deadline, callback, *args, **kwargs)
elif isinstance(deadline, datetime.timedelta):
return self.call_at(
self.time() + deadline.total_seconds(), callback, *args, **kwargs
)
else:
raise TypeError("Unsupported deadline %r" % deadline)
def call_later(
self, delay: float, callback: Callable[..., None], *args: Any, **kwargs: Any
) -> object:
"""Runs the ``callback`` after ``delay`` seconds have passed.
Returns an opaque handle that may be passed to `remove_timeout`
to cancel. Note that unlike the `asyncio` method of the same
name, the returned object does not have a ``cancel()`` method.
See `add_timeout` for comments on thread-safety and subclassing.
.. versionadded:: 4.0
"""
return self.call_at(self.time() + delay, callback, *args, **kwargs)
def call_at(
self, when: float, callback: Callable[..., None], *args: Any, **kwargs: Any
) -> object:
"""Runs the ``callback`` at the absolute time designated by ``when``.
``when`` must be a number using the same reference point as
`IOLoop.time`.
Returns an opaque handle that may be passed to `remove_timeout`
to cancel. Note that unlike the `asyncio` method of the same
name, the returned object does not have a ``cancel()`` method.
See `add_timeout` for comments on thread-safety and subclassing.
.. versionadded:: 4.0
"""
return self.add_timeout(when, callback, *args, **kwargs)
def remove_timeout(self, timeout: object) -> None:
"""Cancels a pending timeout.
The argument is a handle as returned by `add_timeout`. It is
safe to call `remove_timeout` even if the callback has already
been run.
"""
raise NotImplementedError()
def add_callback(self, callback: Callable, *args: Any, **kwargs: Any) -> None:
"""Calls the given callback on the next I/O loop iteration.
It is safe to call this method from any thread at any time,
except from a signal handler. Note that this is the **only**
method in `IOLoop` that makes this thread-safety guarantee; all
other interaction with the `IOLoop` must be done from that
`IOLoop`'s thread. `add_callback()` may be used to transfer
control from other threads to the `IOLoop`'s thread.
To add a callback from a signal handler, see
`add_callback_from_signal`.
"""
raise NotImplementedError()
def add_callback_from_signal(
self, callback: Callable, *args: Any, **kwargs: Any
) -> None:
"""Calls the given callback on the next I/O loop iteration.
Safe for use from a Python signal handler; should not be used
otherwise.
"""
raise NotImplementedError()
def spawn_callback(self, callback: Callable, *args: Any, **kwargs: Any) -> None:
"""Calls the given callback on the next IOLoop iteration.
As of Tornado 6.0, this method is equivalent to `add_callback`.
.. versionadded:: 4.0
"""
self.add_callback(callback, *args, **kwargs)
def add_future(
self,
future: "Union[Future[_T], concurrent.futures.Future[_T]]",
callback: Callable[["Future[_T]"], None],
) -> None:
"""Schedules a callback on the ``IOLoop`` when the given
`.Future` is finished.
The callback is invoked with one argument, the
`.Future`.
This method only accepts `.Future` objects and not other
awaitables (unlike most of Tornado where the two are
interchangeable).
"""
if isinstance(future, Future):
# Note that we specifically do not want the inline behavior of
# tornado.concurrent.future_add_done_callback. We always want
# this callback scheduled on the next IOLoop iteration (which
# asyncio.Future always does).
#
# Wrap the callback in self._run_callback so we control
# the error logging (i.e. it goes to tornado.log.app_log
# instead of asyncio's log).
future.add_done_callback(
lambda f: self._run_callback(functools.partial(callback, future))
)
else:
assert is_future(future)
# For concurrent futures, we use self.add_callback, so
# it's fine if future_add_done_callback inlines that call.
future_add_done_callback(
future, lambda f: self.add_callback(callback, future)
)
def run_in_executor(
self,
executor: Optional[concurrent.futures.Executor],
func: Callable[..., _T],
*args: Any
) -> Awaitable[_T]:
"""Runs a function in a ``concurrent.futures.Executor``. If
``executor`` is ``None``, the IO loop's default executor will be used.
Use `functools.partial` to pass keyword arguments to ``func``.
.. versionadded:: 5.0
"""
if executor is None:
if not hasattr(self, "_executor"):
from tornado.process import cpu_count
self._executor = concurrent.futures.ThreadPoolExecutor(
max_workers=(cpu_count() * 5)
) # type: concurrent.futures.Executor
executor = self._executor
c_future = executor.submit(func, *args)
# Concurrent Futures are not usable with await. Wrap this in a
# Tornado Future instead, using self.add_future for thread-safety.
t_future = Future() # type: Future[_T]
self.add_future(c_future, lambda f: chain_future(f, t_future))
return t_future
def set_default_executor(self, executor: concurrent.futures.Executor) -> None:
"""Sets the default executor to use with :meth:`run_in_executor`.
.. versionadded:: 5.0
"""
self._executor = executor
def _run_callback(self, callback: Callable[[], Any]) -> None:
"""Runs a callback with error handling.
.. versionchanged:: 6.0
CancelledErrors are no longer logged.
"""
try:
ret = callback()
if ret is not None:
from tornado import gen
# Functions that return Futures typically swallow all
# exceptions and store them in the Future. If a Future
# makes it out to the IOLoop, ensure its exception (if any)
# gets logged too.
try:
ret = gen.convert_yielded(ret)
except gen.BadYieldError:
# It's not unusual for add_callback to be used with
# methods returning a non-None and non-yieldable
# result, which should just be ignored.
pass
else:
self.add_future(ret, self._discard_future_result)
except asyncio.CancelledError:
pass
except Exception:
app_log.error("Exception in callback %r", callback, exc_info=True)
def _discard_future_result(self, future: Future) -> None:
"""Avoid unhandled-exception warnings from spawned coroutines."""
future.result()
def split_fd(
self, fd: Union[int, _Selectable]
) -> Tuple[int, Union[int, _Selectable]]:
# """Returns an (fd, obj) pair from an ``fd`` parameter.
# We accept both raw file descriptors and file-like objects as
# input to `add_handler` and related methods. When a file-like
# object is passed, we must retain the object itself so we can
# close it correctly when the `IOLoop` shuts down, but the
# poller interfaces favor file descriptors (they will accept
# file-like objects and call ``fileno()`` for you, but they
# always return the descriptor itself).
# This method is provided for use by `IOLoop` subclasses and should
# not generally be used by application code.
# .. versionadded:: 4.0
# """
if isinstance(fd, int):
return fd, fd
return fd.fileno(), fd
def close_fd(self, fd: Union[int, _Selectable]) -> None:
# """Utility method to close an ``fd``.
# If ``fd`` is a file-like object, we close it directly; otherwise
# we use `os.close`.
# This method is provided for use by `IOLoop` subclasses (in
# implementations of ``IOLoop.close(all_fds=True)`` and should
# not generally be used by application code.
# .. versionadded:: 4.0
# """
try:
if isinstance(fd, int):
os.close(fd)
else:
fd.close()
except OSError:
pass
class _Timeout(object):
"""An IOLoop timeout, a UNIX timestamp and a callback"""
# Reduce memory overhead when there are lots of pending callbacks
__slots__ = ["deadline", "callback", "tdeadline"]
def __init__(
self, deadline: float, callback: Callable[[], None], io_loop: IOLoop
) -> None:
if not isinstance(deadline, numbers.Real):
raise TypeError("Unsupported deadline %r" % deadline)
self.deadline = deadline
self.callback = callback
self.tdeadline = (
deadline,
next(io_loop._timeout_counter),
) # type: Tuple[float, int]
# Comparison methods to sort by deadline, with object id as a tiebreaker
# to guarantee a consistent ordering. The heapq module uses __le__
# in python2.5, and __lt__ in 2.6+ (sort() and most other comparisons
# use __lt__).
def __lt__(self, other: "_Timeout") -> bool:
return self.tdeadline < other.tdeadline
def __le__(self, other: "_Timeout") -> bool:
return self.tdeadline <= other.tdeadline
class PeriodicCallback(object):
"""Schedules the given callback to be called periodically.
The callback is called every ``callback_time`` milliseconds.
Note that the timeout is given in milliseconds, while most other
time-related functions in Tornado use seconds.
If ``jitter`` is specified, each callback time will be randomly selected
within a window of ``jitter * callback_time`` milliseconds.
Jitter can be used to reduce alignment of events with similar periods.
A jitter of 0.1 means allowing a 10% variation in callback time.
The window is centered on ``callback_time`` so the total number of calls
within a given interval should not be significantly affected by adding
jitter.
If the callback runs for longer than ``callback_time`` milliseconds,
subsequent invocations will be skipped to get back on schedule.
`start` must be called after the `PeriodicCallback` is created.
.. versionchanged:: 5.0
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
.. versionchanged:: 5.1
The ``jitter`` argument is added.
"""
def __init__(
self, callback: Callable[[], None], callback_time: float, jitter: float = 0
) -> None:
self.callback = callback
if callback_time <= 0:
raise ValueError("Periodic callback must have a positive callback_time")
self.callback_time = callback_time
self.jitter = jitter
self._running = False
self._timeout = None # type: object
def start(self) -> None:
"""Starts the timer."""
# Looking up the IOLoop here allows to first instantiate the
# PeriodicCallback in another thread, then start it using
# IOLoop.add_callback().
self.io_loop = IOLoop.current()
self._running = True
self._next_timeout = self.io_loop.time()
self._schedule_next()
def stop(self) -> None:
"""Stops the timer."""
self._running = False
if self._timeout is not None:
self.io_loop.remove_timeout(self._timeout)
self._timeout = None
def is_running(self) -> bool:
"""Returns ``True`` if this `.PeriodicCallback` has been started.
.. versionadded:: 4.1
"""
return self._running
def _run(self) -> None:
if not self._running:
return
try:
return self.callback()
except Exception:
app_log.error("Exception in callback %r", self.callback, exc_info=True)
finally:
self._schedule_next()
def _schedule_next(self) -> None:
if self._running:
self._update_next(self.io_loop.time())
self._timeout = self.io_loop.add_timeout(self._next_timeout, self._run)
def _update_next(self, current_time: float) -> None:
callback_time_sec = self.callback_time / 1000.0
if self.jitter:
# apply jitter fraction
callback_time_sec *= 1 + (self.jitter * (random.random() - 0.5))
if self._next_timeout <= current_time:
# The period should be measured from the start of one call
# to the start of the next. If one call takes too long,
# skip cycles to get back to a multiple of the original
# schedule.
self._next_timeout += (
math.floor((current_time - self._next_timeout) / callback_time_sec) + 1
) * callback_time_sec
else:
# If the clock moved backwards, ensure we advance the next
# timeout instead of recomputing the same value again.
# This may result in long gaps between callbacks if the
# clock jumps backwards by a lot, but the far more common
# scenario is a small NTP adjustment that should just be
# ignored.
#
# Note that on some systems if time.time() runs slower
# than time.monotonic() (most common on windows), we
# effectively experience a small backwards time jump on
# every iteration because PeriodicCallback uses
# time.time() while asyncio schedules callbacks using
# time.monotonic().
# https://github.com/tornadoweb/tornado/issues/2333
self._next_timeout += callback_time_sec

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,564 @@
# -*- coding: utf-8 -*-
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Translation methods for generating localized strings.
To load a locale and generate a translated string::
user_locale = tornado.locale.get("es_LA")
print(user_locale.translate("Sign out"))
`tornado.locale.get()` returns the closest matching locale, not necessarily the
specific locale you requested. You can support pluralization with
additional arguments to `~Locale.translate()`, e.g.::
people = [...]
message = user_locale.translate(
"%(list)s is online", "%(list)s are online", len(people))
print(message % {"list": user_locale.list(people)})
The first string is chosen if ``len(people) == 1``, otherwise the second
string is chosen.
Applications should call one of `load_translations` (which uses a simple
CSV format) or `load_gettext_translations` (which uses the ``.mo`` format
supported by `gettext` and related tools). If neither method is called,
the `Locale.translate` method will simply return the original string.
"""
import codecs
import csv
import datetime
import gettext
import os
import re
from tornado import escape
from tornado.log import gen_log
from tornado._locale_data import LOCALE_NAMES
from typing import Iterable, Any, Union, Dict
_default_locale = "en_US"
_translations = {} # type: Dict[str, Any]
_supported_locales = frozenset([_default_locale])
_use_gettext = False
CONTEXT_SEPARATOR = "\x04"
def get(*locale_codes: str) -> "Locale":
"""Returns the closest match for the given locale codes.
We iterate over all given locale codes in order. If we have a tight
or a loose match for the code (e.g., "en" for "en_US"), we return
the locale. Otherwise we move to the next code in the list.
By default we return ``en_US`` if no translations are found for any of
the specified locales. You can change the default locale with
`set_default_locale()`.
"""
return Locale.get_closest(*locale_codes)
def set_default_locale(code: str) -> None:
"""Sets the default locale.
The default locale is assumed to be the language used for all strings
in the system. The translations loaded from disk are mappings from
the default locale to the destination locale. Consequently, you don't
need to create a translation file for the default locale.
"""
global _default_locale
global _supported_locales
_default_locale = code
_supported_locales = frozenset(list(_translations.keys()) + [_default_locale])
def load_translations(directory: str, encoding: str = None) -> None:
"""Loads translations from CSV files in a directory.
Translations are strings with optional Python-style named placeholders
(e.g., ``My name is %(name)s``) and their associated translations.
The directory should have translation files of the form ``LOCALE.csv``,
e.g. ``es_GT.csv``. The CSV files should have two or three columns: string,
translation, and an optional plural indicator. Plural indicators should
be one of "plural" or "singular". A given string can have both singular
and plural forms. For example ``%(name)s liked this`` may have a
different verb conjugation depending on whether %(name)s is one
name or a list of names. There should be two rows in the CSV file for
that string, one with plural indicator "singular", and one "plural".
For strings with no verbs that would change on translation, simply
use "unknown" or the empty string (or don't include the column at all).
The file is read using the `csv` module in the default "excel" dialect.
In this format there should not be spaces after the commas.
If no ``encoding`` parameter is given, the encoding will be
detected automatically (among UTF-8 and UTF-16) if the file
contains a byte-order marker (BOM), defaulting to UTF-8 if no BOM
is present.
Example translation ``es_LA.csv``::
"I love you","Te amo"
"%(name)s liked this","A %(name)s les gustó esto","plural"
"%(name)s liked this","A %(name)s le gustó esto","singular"
.. versionchanged:: 4.3
Added ``encoding`` parameter. Added support for BOM-based encoding
detection, UTF-16, and UTF-8-with-BOM.
"""
global _translations
global _supported_locales
_translations = {}
for path in os.listdir(directory):
if not path.endswith(".csv"):
continue
locale, extension = path.split(".")
if not re.match("[a-z]+(_[A-Z]+)?$", locale):
gen_log.error(
"Unrecognized locale %r (path: %s)",
locale,
os.path.join(directory, path),
)
continue
full_path = os.path.join(directory, path)
if encoding is None:
# Try to autodetect encoding based on the BOM.
with open(full_path, "rb") as bf:
data = bf.read(len(codecs.BOM_UTF16_LE))
if data in (codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE):
encoding = "utf-16"
else:
# utf-8-sig is "utf-8 with optional BOM". It's discouraged
# in most cases but is common with CSV files because Excel
# cannot read utf-8 files without a BOM.
encoding = "utf-8-sig"
# python 3: csv.reader requires a file open in text mode.
# Specify an encoding to avoid dependence on $LANG environment variable.
with open(full_path, encoding=encoding) as f:
_translations[locale] = {}
for i, row in enumerate(csv.reader(f)):
if not row or len(row) < 2:
continue
row = [escape.to_unicode(c).strip() for c in row]
english, translation = row[:2]
if len(row) > 2:
plural = row[2] or "unknown"
else:
plural = "unknown"
if plural not in ("plural", "singular", "unknown"):
gen_log.error(
"Unrecognized plural indicator %r in %s line %d",
plural,
path,
i + 1,
)
continue
_translations[locale].setdefault(plural, {})[english] = translation
_supported_locales = frozenset(list(_translations.keys()) + [_default_locale])
gen_log.debug("Supported locales: %s", sorted(_supported_locales))
def load_gettext_translations(directory: str, domain: str) -> None:
"""Loads translations from `gettext`'s locale tree
Locale tree is similar to system's ``/usr/share/locale``, like::
{directory}/{lang}/LC_MESSAGES/{domain}.mo
Three steps are required to have your app translated:
1. Generate POT translation file::
xgettext --language=Python --keyword=_:1,2 -d mydomain file1.py file2.html etc
2. Merge against existing POT file::
msgmerge old.po mydomain.po > new.po
3. Compile::
msgfmt mydomain.po -o {directory}/pt_BR/LC_MESSAGES/mydomain.mo
"""
import gettext
global _translations
global _supported_locales
global _use_gettext
_translations = {}
for lang in os.listdir(directory):
if lang.startswith("."):
continue # skip .svn, etc
if os.path.isfile(os.path.join(directory, lang)):
continue
try:
os.stat(os.path.join(directory, lang, "LC_MESSAGES", domain + ".mo"))
_translations[lang] = gettext.translation(
domain, directory, languages=[lang]
)
except Exception as e:
gen_log.error("Cannot load translation for '%s': %s", lang, str(e))
continue
_supported_locales = frozenset(list(_translations.keys()) + [_default_locale])
_use_gettext = True
gen_log.debug("Supported locales: %s", sorted(_supported_locales))
def get_supported_locales() -> Iterable[str]:
"""Returns a list of all the supported locale codes."""
return _supported_locales
class Locale(object):
"""Object representing a locale.
After calling one of `load_translations` or `load_gettext_translations`,
call `get` or `get_closest` to get a Locale object.
"""
_cache = {} # type: Dict[str, Locale]
@classmethod
def get_closest(cls, *locale_codes: str) -> "Locale":
"""Returns the closest match for the given locale code."""
for code in locale_codes:
if not code:
continue
code = code.replace("-", "_")
parts = code.split("_")
if len(parts) > 2:
continue
elif len(parts) == 2:
code = parts[0].lower() + "_" + parts[1].upper()
if code in _supported_locales:
return cls.get(code)
if parts[0].lower() in _supported_locales:
return cls.get(parts[0].lower())
return cls.get(_default_locale)
@classmethod
def get(cls, code: str) -> "Locale":
"""Returns the Locale for the given locale code.
If it is not supported, we raise an exception.
"""
if code not in cls._cache:
assert code in _supported_locales
translations = _translations.get(code, None)
if translations is None:
locale = CSVLocale(code, {}) # type: Locale
elif _use_gettext:
locale = GettextLocale(code, translations)
else:
locale = CSVLocale(code, translations)
cls._cache[code] = locale
return cls._cache[code]
def __init__(self, code: str) -> None:
self.code = code
self.name = LOCALE_NAMES.get(code, {}).get("name", u"Unknown")
self.rtl = False
for prefix in ["fa", "ar", "he"]:
if self.code.startswith(prefix):
self.rtl = True
break
# Initialize strings for date formatting
_ = self.translate
self._months = [
_("January"),
_("February"),
_("March"),
_("April"),
_("May"),
_("June"),
_("July"),
_("August"),
_("September"),
_("October"),
_("November"),
_("December"),
]
self._weekdays = [
_("Monday"),
_("Tuesday"),
_("Wednesday"),
_("Thursday"),
_("Friday"),
_("Saturday"),
_("Sunday"),
]
def translate(
self, message: str, plural_message: str = None, count: int = None
) -> str:
"""Returns the translation for the given message for this locale.
If ``plural_message`` is given, you must also provide
``count``. We return ``plural_message`` when ``count != 1``,
and we return the singular form for the given message when
``count == 1``.
"""
raise NotImplementedError()
def pgettext(
self, context: str, message: str, plural_message: str = None, count: int = None
) -> str:
raise NotImplementedError()
def format_date(
self,
date: Union[int, float, datetime.datetime],
gmt_offset: int = 0,
relative: bool = True,
shorter: bool = False,
full_format: bool = False,
) -> str:
"""Formats the given date (which should be GMT).
By default, we return a relative time (e.g., "2 minutes ago"). You
can return an absolute date string with ``relative=False``.
You can force a full format date ("July 10, 1980") with
``full_format=True``.
This method is primarily intended for dates in the past.
For dates in the future, we fall back to full format.
"""
if isinstance(date, (int, float)):
date = datetime.datetime.utcfromtimestamp(date)
now = datetime.datetime.utcnow()
if date > now:
if relative and (date - now).seconds < 60:
# Due to click skew, things are some things slightly
# in the future. Round timestamps in the immediate
# future down to now in relative mode.
date = now
else:
# Otherwise, future dates always use the full format.
full_format = True
local_date = date - datetime.timedelta(minutes=gmt_offset)
local_now = now - datetime.timedelta(minutes=gmt_offset)
local_yesterday = local_now - datetime.timedelta(hours=24)
difference = now - date
seconds = difference.seconds
days = difference.days
_ = self.translate
format = None
if not full_format:
if relative and days == 0:
if seconds < 50:
return _("1 second ago", "%(seconds)d seconds ago", seconds) % {
"seconds": seconds
}
if seconds < 50 * 60:
minutes = round(seconds / 60.0)
return _("1 minute ago", "%(minutes)d minutes ago", minutes) % {
"minutes": minutes
}
hours = round(seconds / (60.0 * 60))
return _("1 hour ago", "%(hours)d hours ago", hours) % {"hours": hours}
if days == 0:
format = _("%(time)s")
elif days == 1 and local_date.day == local_yesterday.day and relative:
format = _("yesterday") if shorter else _("yesterday at %(time)s")
elif days < 5:
format = _("%(weekday)s") if shorter else _("%(weekday)s at %(time)s")
elif days < 334: # 11mo, since confusing for same month last year
format = (
_("%(month_name)s %(day)s")
if shorter
else _("%(month_name)s %(day)s at %(time)s")
)
if format is None:
format = (
_("%(month_name)s %(day)s, %(year)s")
if shorter
else _("%(month_name)s %(day)s, %(year)s at %(time)s")
)
tfhour_clock = self.code not in ("en", "en_US", "zh_CN")
if tfhour_clock:
str_time = "%d:%02d" % (local_date.hour, local_date.minute)
elif self.code == "zh_CN":
str_time = "%s%d:%02d" % (
(u"\u4e0a\u5348", u"\u4e0b\u5348")[local_date.hour >= 12],
local_date.hour % 12 or 12,
local_date.minute,
)
else:
str_time = "%d:%02d %s" % (
local_date.hour % 12 or 12,
local_date.minute,
("am", "pm")[local_date.hour >= 12],
)
return format % {
"month_name": self._months[local_date.month - 1],
"weekday": self._weekdays[local_date.weekday()],
"day": str(local_date.day),
"year": str(local_date.year),
"time": str_time,
}
def format_day(
self, date: datetime.datetime, gmt_offset: int = 0, dow: bool = True
) -> bool:
"""Formats the given date as a day of week.
Example: "Monday, January 22". You can remove the day of week with
``dow=False``.
"""
local_date = date - datetime.timedelta(minutes=gmt_offset)
_ = self.translate
if dow:
return _("%(weekday)s, %(month_name)s %(day)s") % {
"month_name": self._months[local_date.month - 1],
"weekday": self._weekdays[local_date.weekday()],
"day": str(local_date.day),
}
else:
return _("%(month_name)s %(day)s") % {
"month_name": self._months[local_date.month - 1],
"day": str(local_date.day),
}
def list(self, parts: Any) -> str:
"""Returns a comma-separated list for the given list of parts.
The format is, e.g., "A, B and C", "A and B" or just "A" for lists
of size 1.
"""
_ = self.translate
if len(parts) == 0:
return ""
if len(parts) == 1:
return parts[0]
comma = u" \u0648 " if self.code.startswith("fa") else u", "
return _("%(commas)s and %(last)s") % {
"commas": comma.join(parts[:-1]),
"last": parts[len(parts) - 1],
}
def friendly_number(self, value: int) -> str:
"""Returns a comma-separated number for the given integer."""
if self.code not in ("en", "en_US"):
return str(value)
s = str(value)
parts = []
while s:
parts.append(s[-3:])
s = s[:-3]
return ",".join(reversed(parts))
class CSVLocale(Locale):
"""Locale implementation using tornado's CSV translation format."""
def __init__(self, code: str, translations: Dict[str, Dict[str, str]]) -> None:
self.translations = translations
super(CSVLocale, self).__init__(code)
def translate(
self, message: str, plural_message: str = None, count: int = None
) -> str:
if plural_message is not None:
assert count is not None
if count != 1:
message = plural_message
message_dict = self.translations.get("plural", {})
else:
message_dict = self.translations.get("singular", {})
else:
message_dict = self.translations.get("unknown", {})
return message_dict.get(message, message)
def pgettext(
self, context: str, message: str, plural_message: str = None, count: int = None
) -> str:
if self.translations:
gen_log.warning("pgettext is not supported by CSVLocale")
return self.translate(message, plural_message, count)
class GettextLocale(Locale):
"""Locale implementation using the `gettext` module."""
def __init__(self, code: str, translations: gettext.NullTranslations) -> None:
self.ngettext = translations.ngettext
self.gettext = translations.gettext
# self.gettext must exist before __init__ is called, since it
# calls into self.translate
super(GettextLocale, self).__init__(code)
def translate(
self, message: str, plural_message: str = None, count: int = None
) -> str:
if plural_message is not None:
assert count is not None
return self.ngettext(message, plural_message, count)
else:
return self.gettext(message)
def pgettext(
self, context: str, message: str, plural_message: str = None, count: int = None
) -> str:
"""Allows to set context for translation, accepts plural forms.
Usage example::
pgettext("law", "right")
pgettext("good", "right")
Plural message example::
pgettext("organization", "club", "clubs", len(clubs))
pgettext("stick", "club", "clubs", len(clubs))
To generate POT file with context, add following options to step 1
of `load_gettext_translations` sequence::
xgettext [basic options] --keyword=pgettext:1c,2 --keyword=pgettext:1c,2,3
.. versionadded:: 4.2
"""
if plural_message is not None:
assert count is not None
msgs_with_ctxt = (
"%s%s%s" % (context, CONTEXT_SEPARATOR, message),
"%s%s%s" % (context, CONTEXT_SEPARATOR, plural_message),
count,
)
result = self.ngettext(*msgs_with_ctxt)
if CONTEXT_SEPARATOR in result:
# Translation not found
result = self.ngettext(message, plural_message, count)
return result
else:
msg_with_ctxt = "%s%s%s" % (context, CONTEXT_SEPARATOR, message)
result = self.gettext(msg_with_ctxt)
if CONTEXT_SEPARATOR in result:
# Translation not found
result = message
return result

View File

@@ -0,0 +1,570 @@
# Copyright 2015 The Tornado Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import collections
from concurrent.futures import CancelledError
import datetime
import types
from tornado import gen, ioloop
from tornado.concurrent import Future, future_set_result_unless_cancelled
from typing import Union, Optional, Type, Any, Awaitable
import typing
if typing.TYPE_CHECKING:
from typing import Deque, Set # noqa: F401
__all__ = ["Condition", "Event", "Semaphore", "BoundedSemaphore", "Lock"]
class _TimeoutGarbageCollector(object):
"""Base class for objects that periodically clean up timed-out waiters.
Avoids memory leak in a common pattern like:
while True:
yield condition.wait(short_timeout)
print('looping....')
"""
def __init__(self) -> None:
self._waiters = collections.deque() # type: Deque[Future]
self._timeouts = 0
def _garbage_collect(self) -> None:
# Occasionally clear timed-out waiters.
self._timeouts += 1
if self._timeouts > 100:
self._timeouts = 0
self._waiters = collections.deque(w for w in self._waiters if not w.done())
class Condition(_TimeoutGarbageCollector):
"""A condition allows one or more coroutines to wait until notified.
Like a standard `threading.Condition`, but does not need an underlying lock
that is acquired and released.
With a `Condition`, coroutines can wait to be notified by other coroutines:
.. testcode::
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.locks import Condition
condition = Condition()
async def waiter():
print("I'll wait right here")
await condition.wait()
print("I'm done waiting")
async def notifier():
print("About to notify")
condition.notify()
print("Done notifying")
async def runner():
# Wait for waiter() and notifier() in parallel
await gen.multi([waiter(), notifier()])
IOLoop.current().run_sync(runner)
.. testoutput::
I'll wait right here
About to notify
Done notifying
I'm done waiting
`wait` takes an optional ``timeout`` argument, which is either an absolute
timestamp::
io_loop = IOLoop.current()
# Wait up to 1 second for a notification.
await condition.wait(timeout=io_loop.time() + 1)
...or a `datetime.timedelta` for a timeout relative to the current time::
# Wait up to 1 second.
await condition.wait(timeout=datetime.timedelta(seconds=1))
The method returns False if there's no notification before the deadline.
.. versionchanged:: 5.0
Previously, waiters could be notified synchronously from within
`notify`. Now, the notification will always be received on the
next iteration of the `.IOLoop`.
"""
def __init__(self) -> None:
super(Condition, self).__init__()
self.io_loop = ioloop.IOLoop.current()
def __repr__(self) -> str:
result = "<%s" % (self.__class__.__name__,)
if self._waiters:
result += " waiters[%s]" % len(self._waiters)
return result + ">"
def wait(self, timeout: Union[float, datetime.timedelta] = None) -> Awaitable[bool]:
"""Wait for `.notify`.
Returns a `.Future` that resolves ``True`` if the condition is notified,
or ``False`` after a timeout.
"""
waiter = Future() # type: Future[bool]
self._waiters.append(waiter)
if timeout:
def on_timeout() -> None:
if not waiter.done():
future_set_result_unless_cancelled(waiter, False)
self._garbage_collect()
io_loop = ioloop.IOLoop.current()
timeout_handle = io_loop.add_timeout(timeout, on_timeout)
waiter.add_done_callback(lambda _: io_loop.remove_timeout(timeout_handle))
return waiter
def notify(self, n: int = 1) -> None:
"""Wake ``n`` waiters."""
waiters = [] # Waiters we plan to run right now.
while n and self._waiters:
waiter = self._waiters.popleft()
if not waiter.done(): # Might have timed out.
n -= 1
waiters.append(waiter)
for waiter in waiters:
future_set_result_unless_cancelled(waiter, True)
def notify_all(self) -> None:
"""Wake all waiters."""
self.notify(len(self._waiters))
class Event(object):
"""An event blocks coroutines until its internal flag is set to True.
Similar to `threading.Event`.
A coroutine can wait for an event to be set. Once it is set, calls to
``yield event.wait()`` will not block unless the event has been cleared:
.. testcode::
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.locks import Event
event = Event()
async def waiter():
print("Waiting for event")
await event.wait()
print("Not waiting this time")
await event.wait()
print("Done")
async def setter():
print("About to set the event")
event.set()
async def runner():
await gen.multi([waiter(), setter()])
IOLoop.current().run_sync(runner)
.. testoutput::
Waiting for event
About to set the event
Not waiting this time
Done
"""
def __init__(self) -> None:
self._value = False
self._waiters = set() # type: Set[Future[None]]
def __repr__(self) -> str:
return "<%s %s>" % (
self.__class__.__name__,
"set" if self.is_set() else "clear",
)
def is_set(self) -> bool:
"""Return ``True`` if the internal flag is true."""
return self._value
def set(self) -> None:
"""Set the internal flag to ``True``. All waiters are awakened.
Calling `.wait` once the flag is set will not block.
"""
if not self._value:
self._value = True
for fut in self._waiters:
if not fut.done():
fut.set_result(None)
def clear(self) -> None:
"""Reset the internal flag to ``False``.
Calls to `.wait` will block until `.set` is called.
"""
self._value = False
def wait(self, timeout: Union[float, datetime.timedelta] = None) -> Awaitable[None]:
"""Block until the internal flag is true.
Returns an awaitable, which raises `tornado.util.TimeoutError` after a
timeout.
"""
fut = Future() # type: Future[None]
if self._value:
fut.set_result(None)
return fut
self._waiters.add(fut)
fut.add_done_callback(lambda fut: self._waiters.remove(fut))
if timeout is None:
return fut
else:
timeout_fut = gen.with_timeout(
timeout, fut, quiet_exceptions=(CancelledError,)
)
# This is a slightly clumsy workaround for the fact that
# gen.with_timeout doesn't cancel its futures. Cancelling
# fut will remove it from the waiters list.
timeout_fut.add_done_callback(
lambda tf: fut.cancel() if not fut.done() else None
)
return timeout_fut
class _ReleasingContextManager(object):
"""Releases a Lock or Semaphore at the end of a "with" statement.
with (yield semaphore.acquire()):
pass
# Now semaphore.release() has been called.
"""
def __init__(self, obj: Any) -> None:
self._obj = obj
def __enter__(self) -> None:
pass
def __exit__(
self,
exc_type: "Optional[Type[BaseException]]",
exc_val: Optional[BaseException],
exc_tb: Optional[types.TracebackType],
) -> None:
self._obj.release()
class Semaphore(_TimeoutGarbageCollector):
"""A lock that can be acquired a fixed number of times before blocking.
A Semaphore manages a counter representing the number of `.release` calls
minus the number of `.acquire` calls, plus an initial value. The `.acquire`
method blocks if necessary until it can return without making the counter
negative.
Semaphores limit access to a shared resource. To allow access for two
workers at a time:
.. testsetup:: semaphore
from collections import deque
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.concurrent import Future
# Ensure reliable doctest output: resolve Futures one at a time.
futures_q = deque([Future() for _ in range(3)])
async def simulator(futures):
for f in futures:
# simulate the asynchronous passage of time
await gen.sleep(0)
await gen.sleep(0)
f.set_result(None)
IOLoop.current().add_callback(simulator, list(futures_q))
def use_some_resource():
return futures_q.popleft()
.. testcode:: semaphore
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.locks import Semaphore
sem = Semaphore(2)
async def worker(worker_id):
await sem.acquire()
try:
print("Worker %d is working" % worker_id)
await use_some_resource()
finally:
print("Worker %d is done" % worker_id)
sem.release()
async def runner():
# Join all workers.
await gen.multi([worker(i) for i in range(3)])
IOLoop.current().run_sync(runner)
.. testoutput:: semaphore
Worker 0 is working
Worker 1 is working
Worker 0 is done
Worker 2 is working
Worker 1 is done
Worker 2 is done
Workers 0 and 1 are allowed to run concurrently, but worker 2 waits until
the semaphore has been released once, by worker 0.
The semaphore can be used as an async context manager::
async def worker(worker_id):
async with sem:
print("Worker %d is working" % worker_id)
await use_some_resource()
# Now the semaphore has been released.
print("Worker %d is done" % worker_id)
For compatibility with older versions of Python, `.acquire` is a
context manager, so ``worker`` could also be written as::
@gen.coroutine
def worker(worker_id):
with (yield sem.acquire()):
print("Worker %d is working" % worker_id)
yield use_some_resource()
# Now the semaphore has been released.
print("Worker %d is done" % worker_id)
.. versionchanged:: 4.3
Added ``async with`` support in Python 3.5.
"""
def __init__(self, value: int = 1) -> None:
super(Semaphore, self).__init__()
if value < 0:
raise ValueError("semaphore initial value must be >= 0")
self._value = value
def __repr__(self) -> str:
res = super(Semaphore, self).__repr__()
extra = (
"locked" if self._value == 0 else "unlocked,value:{0}".format(self._value)
)
if self._waiters:
extra = "{0},waiters:{1}".format(extra, len(self._waiters))
return "<{0} [{1}]>".format(res[1:-1], extra)
def release(self) -> None:
"""Increment the counter and wake one waiter."""
self._value += 1
while self._waiters:
waiter = self._waiters.popleft()
if not waiter.done():
self._value -= 1
# If the waiter is a coroutine paused at
#
# with (yield semaphore.acquire()):
#
# then the context manager's __exit__ calls release() at the end
# of the "with" block.
waiter.set_result(_ReleasingContextManager(self))
break
def acquire(
self, timeout: Union[float, datetime.timedelta] = None
) -> Awaitable[_ReleasingContextManager]:
"""Decrement the counter. Returns an awaitable.
Block if the counter is zero and wait for a `.release`. The awaitable
raises `.TimeoutError` after the deadline.
"""
waiter = Future() # type: Future[_ReleasingContextManager]
if self._value > 0:
self._value -= 1
waiter.set_result(_ReleasingContextManager(self))
else:
self._waiters.append(waiter)
if timeout:
def on_timeout() -> None:
if not waiter.done():
waiter.set_exception(gen.TimeoutError())
self._garbage_collect()
io_loop = ioloop.IOLoop.current()
timeout_handle = io_loop.add_timeout(timeout, on_timeout)
waiter.add_done_callback(
lambda _: io_loop.remove_timeout(timeout_handle)
)
return waiter
def __enter__(self) -> None:
raise RuntimeError("Use 'async with' instead of 'with' for Semaphore")
def __exit__(
self,
typ: "Optional[Type[BaseException]]",
value: Optional[BaseException],
traceback: Optional[types.TracebackType],
) -> None:
self.__enter__()
async def __aenter__(self) -> None:
await self.acquire()
async def __aexit__(
self,
typ: "Optional[Type[BaseException]]",
value: Optional[BaseException],
tb: Optional[types.TracebackType],
) -> None:
self.release()
class BoundedSemaphore(Semaphore):
"""A semaphore that prevents release() being called too many times.
If `.release` would increment the semaphore's value past the initial
value, it raises `ValueError`. Semaphores are mostly used to guard
resources with limited capacity, so a semaphore released too many times
is a sign of a bug.
"""
def __init__(self, value: int = 1) -> None:
super(BoundedSemaphore, self).__init__(value=value)
self._initial_value = value
def release(self) -> None:
"""Increment the counter and wake one waiter."""
if self._value >= self._initial_value:
raise ValueError("Semaphore released too many times")
super(BoundedSemaphore, self).release()
class Lock(object):
"""A lock for coroutines.
A Lock begins unlocked, and `acquire` locks it immediately. While it is
locked, a coroutine that yields `acquire` waits until another coroutine
calls `release`.
Releasing an unlocked lock raises `RuntimeError`.
A Lock can be used as an async context manager with the ``async
with`` statement:
>>> from tornado import locks
>>> lock = locks.Lock()
>>>
>>> async def f():
... async with lock:
... # Do something holding the lock.
... pass
...
... # Now the lock is released.
For compatibility with older versions of Python, the `.acquire`
method asynchronously returns a regular context manager:
>>> async def f2():
... with (yield lock.acquire()):
... # Do something holding the lock.
... pass
...
... # Now the lock is released.
.. versionchanged:: 4.3
Added ``async with`` support in Python 3.5.
"""
def __init__(self) -> None:
self._block = BoundedSemaphore(value=1)
def __repr__(self) -> str:
return "<%s _block=%s>" % (self.__class__.__name__, self._block)
def acquire(
self, timeout: Union[float, datetime.timedelta] = None
) -> Awaitable[_ReleasingContextManager]:
"""Attempt to lock. Returns an awaitable.
Returns an awaitable, which raises `tornado.util.TimeoutError` after a
timeout.
"""
return self._block.acquire(timeout)
def release(self) -> None:
"""Unlock.
The first coroutine in line waiting for `acquire` gets the lock.
If not locked, raise a `RuntimeError`.
"""
try:
self._block.release()
except ValueError:
raise RuntimeError("release unlocked lock")
def __enter__(self) -> None:
raise RuntimeError("Use `async with` instead of `with` for Lock")
def __exit__(
self,
typ: "Optional[Type[BaseException]]",
value: Optional[BaseException],
tb: Optional[types.TracebackType],
) -> None:
self.__enter__()
async def __aenter__(self) -> None:
await self.acquire()
async def __aexit__(
self,
typ: "Optional[Type[BaseException]]",
value: Optional[BaseException],
tb: Optional[types.TracebackType],
) -> None:
self.release()

View File

@@ -0,0 +1,336 @@
#
# Copyright 2012 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Logging support for Tornado.
Tornado uses three logger streams:
* ``tornado.access``: Per-request logging for Tornado's HTTP servers (and
potentially other servers in the future)
* ``tornado.application``: Logging of errors from application code (i.e.
uncaught exceptions from callbacks)
* ``tornado.general``: General-purpose logging, including any errors
or warnings from Tornado itself.
These streams may be configured independently using the standard library's
`logging` module. For example, you may wish to send ``tornado.access`` logs
to a separate file for analysis.
"""
import logging
import logging.handlers
import sys
from tornado.escape import _unicode
from tornado.util import unicode_type, basestring_type
try:
import colorama # type: ignore
except ImportError:
colorama = None
try:
import curses
except ImportError:
curses = None # type: ignore
from typing import Dict, Any, cast
# Logger objects for internal tornado use
access_log = logging.getLogger("tornado.access")
app_log = logging.getLogger("tornado.application")
gen_log = logging.getLogger("tornado.general")
def _stderr_supports_color() -> bool:
try:
if hasattr(sys.stderr, "isatty") and sys.stderr.isatty():
if curses:
curses.setupterm()
if curses.tigetnum("colors") > 0:
return True
elif colorama:
if sys.stderr is getattr(
colorama.initialise, "wrapped_stderr", object()
):
return True
except Exception:
# Very broad exception handling because it's always better to
# fall back to non-colored logs than to break at startup.
pass
return False
def _safe_unicode(s: Any) -> str:
try:
return _unicode(s)
except UnicodeDecodeError:
return repr(s)
class LogFormatter(logging.Formatter):
"""Log formatter used in Tornado.
Key features of this formatter are:
* Color support when logging to a terminal that supports it.
* Timestamps on every log line.
* Robust against str/bytes encoding problems.
This formatter is enabled automatically by
`tornado.options.parse_command_line` or `tornado.options.parse_config_file`
(unless ``--logging=none`` is used).
Color support on Windows versions that do not support ANSI color codes is
enabled by use of the colorama__ library. Applications that wish to use
this must first initialize colorama with a call to ``colorama.init``.
See the colorama documentation for details.
__ https://pypi.python.org/pypi/colorama
.. versionchanged:: 4.5
Added support for ``colorama``. Changed the constructor
signature to be compatible with `logging.config.dictConfig`.
"""
DEFAULT_FORMAT = "%(color)s[%(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]%(end_color)s %(message)s" # noqa: E501
DEFAULT_DATE_FORMAT = "%y%m%d %H:%M:%S"
DEFAULT_COLORS = {
logging.DEBUG: 4, # Blue
logging.INFO: 2, # Green
logging.WARNING: 3, # Yellow
logging.ERROR: 1, # Red
}
def __init__(
self,
fmt: str = DEFAULT_FORMAT,
datefmt: str = DEFAULT_DATE_FORMAT,
style: str = "%",
color: bool = True,
colors: Dict[int, int] = DEFAULT_COLORS,
) -> None:
r"""
:arg bool color: Enables color support.
:arg str fmt: Log message format.
It will be applied to the attributes dict of log records. The
text between ``%(color)s`` and ``%(end_color)s`` will be colored
depending on the level if color support is on.
:arg dict colors: color mappings from logging level to terminal color
code
:arg str datefmt: Datetime format.
Used for formatting ``(asctime)`` placeholder in ``prefix_fmt``.
.. versionchanged:: 3.2
Added ``fmt`` and ``datefmt`` arguments.
"""
logging.Formatter.__init__(self, datefmt=datefmt)
self._fmt = fmt
self._colors = {} # type: Dict[int, str]
if color and _stderr_supports_color():
if curses is not None:
fg_color = curses.tigetstr("setaf") or curses.tigetstr("setf") or b""
for levelno, code in colors.items():
# Convert the terminal control characters from
# bytes to unicode strings for easier use with the
# logging module.
self._colors[levelno] = unicode_type(
curses.tparm(fg_color, code), "ascii"
)
self._normal = unicode_type(curses.tigetstr("sgr0"), "ascii")
else:
# If curses is not present (currently we'll only get here for
# colorama on windows), assume hard-coded ANSI color codes.
for levelno, code in colors.items():
self._colors[levelno] = "\033[2;3%dm" % code
self._normal = "\033[0m"
else:
self._normal = ""
def format(self, record: Any) -> str:
try:
message = record.getMessage()
assert isinstance(message, basestring_type) # guaranteed by logging
# Encoding notes: The logging module prefers to work with character
# strings, but only enforces that log messages are instances of
# basestring. In python 2, non-ascii bytestrings will make
# their way through the logging framework until they blow up with
# an unhelpful decoding error (with this formatter it happens
# when we attach the prefix, but there are other opportunities for
# exceptions further along in the framework).
#
# If a byte string makes it this far, convert it to unicode to
# ensure it will make it out to the logs. Use repr() as a fallback
# to ensure that all byte strings can be converted successfully,
# but don't do it by default so we don't add extra quotes to ascii
# bytestrings. This is a bit of a hacky place to do this, but
# it's worth it since the encoding errors that would otherwise
# result are so useless (and tornado is fond of using utf8-encoded
# byte strings wherever possible).
record.message = _safe_unicode(message)
except Exception as e:
record.message = "Bad message (%r): %r" % (e, record.__dict__)
record.asctime = self.formatTime(record, cast(str, self.datefmt))
if record.levelno in self._colors:
record.color = self._colors[record.levelno]
record.end_color = self._normal
else:
record.color = record.end_color = ""
formatted = self._fmt % record.__dict__
if record.exc_info:
if not record.exc_text:
record.exc_text = self.formatException(record.exc_info)
if record.exc_text:
# exc_text contains multiple lines. We need to _safe_unicode
# each line separately so that non-utf8 bytes don't cause
# all the newlines to turn into '\n'.
lines = [formatted.rstrip()]
lines.extend(_safe_unicode(ln) for ln in record.exc_text.split("\n"))
formatted = "\n".join(lines)
return formatted.replace("\n", "\n ")
def enable_pretty_logging(options: Any = None, logger: logging.Logger = None) -> None:
"""Turns on formatted logging output as configured.
This is called automatically by `tornado.options.parse_command_line`
and `tornado.options.parse_config_file`.
"""
if options is None:
import tornado.options
options = tornado.options.options
if options.logging is None or options.logging.lower() == "none":
return
if logger is None:
logger = logging.getLogger()
logger.setLevel(getattr(logging, options.logging.upper()))
if options.log_file_prefix:
rotate_mode = options.log_rotate_mode
if rotate_mode == "size":
channel = logging.handlers.RotatingFileHandler(
filename=options.log_file_prefix,
maxBytes=options.log_file_max_size,
backupCount=options.log_file_num_backups,
encoding="utf-8",
) # type: logging.Handler
elif rotate_mode == "time":
channel = logging.handlers.TimedRotatingFileHandler(
filename=options.log_file_prefix,
when=options.log_rotate_when,
interval=options.log_rotate_interval,
backupCount=options.log_file_num_backups,
encoding="utf-8",
)
else:
error_message = (
"The value of log_rotate_mode option should be "
+ '"size" or "time", not "%s".' % rotate_mode
)
raise ValueError(error_message)
channel.setFormatter(LogFormatter(color=False))
logger.addHandler(channel)
if options.log_to_stderr or (options.log_to_stderr is None and not logger.handlers):
# Set up color if we are in a tty and curses is installed
channel = logging.StreamHandler()
channel.setFormatter(LogFormatter())
logger.addHandler(channel)
def define_logging_options(options: Any = None) -> None:
"""Add logging-related flags to ``options``.
These options are present automatically on the default options instance;
this method is only necessary if you have created your own `.OptionParser`.
.. versionadded:: 4.2
This function existed in prior versions but was broken and undocumented until 4.2.
"""
if options is None:
# late import to prevent cycle
import tornado.options
options = tornado.options.options
options.define(
"logging",
default="info",
help=(
"Set the Python log level. If 'none', tornado won't touch the "
"logging configuration."
),
metavar="debug|info|warning|error|none",
)
options.define(
"log_to_stderr",
type=bool,
default=None,
help=(
"Send log output to stderr (colorized if possible). "
"By default use stderr if --log_file_prefix is not set and "
"no other logging is configured."
),
)
options.define(
"log_file_prefix",
type=str,
default=None,
metavar="PATH",
help=(
"Path prefix for log files. "
"Note that if you are running multiple tornado processes, "
"log_file_prefix must be different for each of them (e.g. "
"include the port number)"
),
)
options.define(
"log_file_max_size",
type=int,
default=100 * 1000 * 1000,
help="max size of log files before rollover",
)
options.define(
"log_file_num_backups", type=int, default=10, help="number of log files to keep"
)
options.define(
"log_rotate_when",
type=str,
default="midnight",
help=(
"specify the type of TimedRotatingFileHandler interval "
"other options:('S', 'M', 'H', 'D', 'W0'-'W6')"
),
)
options.define(
"log_rotate_interval",
type=int,
default=1,
help="The interval value of timed rotating",
)
options.define(
"log_rotate_mode",
type=str,
default="size",
help="The mode of rotating files(time or size)",
)
options.add_parse_callback(lambda: enable_pretty_logging(options))

View File

@@ -0,0 +1,614 @@
#
# Copyright 2011 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Miscellaneous network utility code."""
import concurrent.futures
import errno
import os
import sys
import socket
import ssl
import stat
from tornado.concurrent import dummy_executor, run_on_executor
from tornado.ioloop import IOLoop
from tornado.platform.auto import set_close_exec
from tornado.util import Configurable, errno_from_exception
import typing
from typing import List, Callable, Any, Type, Dict, Union, Tuple, Awaitable
if typing.TYPE_CHECKING:
from asyncio import Future # noqa: F401
# Note that the naming of ssl.Purpose is confusing; the purpose
# of a context is to authentiate the opposite side of the connection.
_client_ssl_defaults = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
_server_ssl_defaults = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
if hasattr(ssl, "OP_NO_COMPRESSION"):
# See netutil.ssl_options_to_context
_client_ssl_defaults.options |= ssl.OP_NO_COMPRESSION
_server_ssl_defaults.options |= ssl.OP_NO_COMPRESSION
# ThreadedResolver runs getaddrinfo on a thread. If the hostname is unicode,
# getaddrinfo attempts to import encodings.idna. If this is done at
# module-import time, the import lock is already held by the main thread,
# leading to deadlock. Avoid it by caching the idna encoder on the main
# thread now.
u"foo".encode("idna")
# For undiagnosed reasons, 'latin1' codec may also need to be preloaded.
u"foo".encode("latin1")
# These errnos indicate that a non-blocking operation must be retried
# at a later time. On most platforms they're the same value, but on
# some they differ.
_ERRNO_WOULDBLOCK = (errno.EWOULDBLOCK, errno.EAGAIN)
if hasattr(errno, "WSAEWOULDBLOCK"):
_ERRNO_WOULDBLOCK += (errno.WSAEWOULDBLOCK,) # type: ignore
# Default backlog used when calling sock.listen()
_DEFAULT_BACKLOG = 128
def bind_sockets(
port: int,
address: str = None,
family: socket.AddressFamily = socket.AF_UNSPEC,
backlog: int = _DEFAULT_BACKLOG,
flags: int = None,
reuse_port: bool = False,
) -> List[socket.socket]:
"""Creates listening sockets bound to the given port and address.
Returns a list of socket objects (multiple sockets are returned if
the given address maps to multiple IP addresses, which is most common
for mixed IPv4 and IPv6 use).
Address may be either an IP address or hostname. If it's a hostname,
the server will listen on all IP addresses associated with the
name. Address may be an empty string or None to listen on all
available interfaces. Family may be set to either `socket.AF_INET`
or `socket.AF_INET6` to restrict to IPv4 or IPv6 addresses, otherwise
both will be used if available.
The ``backlog`` argument has the same meaning as for
`socket.listen() <socket.socket.listen>`.
``flags`` is a bitmask of AI_* flags to `~socket.getaddrinfo`, like
``socket.AI_PASSIVE | socket.AI_NUMERICHOST``.
``reuse_port`` option sets ``SO_REUSEPORT`` option for every socket
in the list. If your platform doesn't support this option ValueError will
be raised.
"""
if reuse_port and not hasattr(socket, "SO_REUSEPORT"):
raise ValueError("the platform doesn't support SO_REUSEPORT")
sockets = []
if address == "":
address = None
if not socket.has_ipv6 and family == socket.AF_UNSPEC:
# Python can be compiled with --disable-ipv6, which causes
# operations on AF_INET6 sockets to fail, but does not
# automatically exclude those results from getaddrinfo
# results.
# http://bugs.python.org/issue16208
family = socket.AF_INET
if flags is None:
flags = socket.AI_PASSIVE
bound_port = None
unique_addresses = set() # type: set
for res in sorted(
socket.getaddrinfo(address, port, family, socket.SOCK_STREAM, 0, flags),
key=lambda x: x[0],
):
if res in unique_addresses:
continue
unique_addresses.add(res)
af, socktype, proto, canonname, sockaddr = res
if (
sys.platform == "darwin"
and address == "localhost"
and af == socket.AF_INET6
and sockaddr[3] != 0
):
# Mac OS X includes a link-local address fe80::1%lo0 in the
# getaddrinfo results for 'localhost'. However, the firewall
# doesn't understand that this is a local address and will
# prompt for access (often repeatedly, due to an apparent
# bug in its ability to remember granting access to an
# application). Skip these addresses.
continue
try:
sock = socket.socket(af, socktype, proto)
except socket.error as e:
if errno_from_exception(e) == errno.EAFNOSUPPORT:
continue
raise
set_close_exec(sock.fileno())
if os.name != "nt":
try:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
except socket.error as e:
if errno_from_exception(e) != errno.ENOPROTOOPT:
# Hurd doesn't support SO_REUSEADDR.
raise
if reuse_port:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
if af == socket.AF_INET6:
# On linux, ipv6 sockets accept ipv4 too by default,
# but this makes it impossible to bind to both
# 0.0.0.0 in ipv4 and :: in ipv6. On other systems,
# separate sockets *must* be used to listen for both ipv4
# and ipv6. For consistency, always disable ipv4 on our
# ipv6 sockets and use a separate ipv4 socket when needed.
#
# Python 2.x on windows doesn't have IPPROTO_IPV6.
if hasattr(socket, "IPPROTO_IPV6"):
sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
# automatic port allocation with port=None
# should bind on the same port on IPv4 and IPv6
host, requested_port = sockaddr[:2]
if requested_port == 0 and bound_port is not None:
sockaddr = tuple([host, bound_port] + list(sockaddr[2:]))
sock.setblocking(False)
sock.bind(sockaddr)
bound_port = sock.getsockname()[1]
sock.listen(backlog)
sockets.append(sock)
return sockets
if hasattr(socket, "AF_UNIX"):
def bind_unix_socket(
file: str, mode: int = 0o600, backlog: int = _DEFAULT_BACKLOG
) -> socket.socket:
"""Creates a listening unix socket.
If a socket with the given name already exists, it will be deleted.
If any other file with that name exists, an exception will be
raised.
Returns a socket object (not a list of socket objects like
`bind_sockets`)
"""
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
set_close_exec(sock.fileno())
try:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
except socket.error as e:
if errno_from_exception(e) != errno.ENOPROTOOPT:
# Hurd doesn't support SO_REUSEADDR
raise
sock.setblocking(False)
try:
st = os.stat(file)
except OSError as err:
if errno_from_exception(err) != errno.ENOENT:
raise
else:
if stat.S_ISSOCK(st.st_mode):
os.remove(file)
else:
raise ValueError("File %s exists and is not a socket", file)
sock.bind(file)
os.chmod(file, mode)
sock.listen(backlog)
return sock
def add_accept_handler(
sock: socket.socket, callback: Callable[[socket.socket, Any], None]
) -> Callable[[], None]:
"""Adds an `.IOLoop` event handler to accept new connections on ``sock``.
When a connection is accepted, ``callback(connection, address)`` will
be run (``connection`` is a socket object, and ``address`` is the
address of the other end of the connection). Note that this signature
is different from the ``callback(fd, events)`` signature used for
`.IOLoop` handlers.
A callable is returned which, when called, will remove the `.IOLoop`
event handler and stop processing further incoming connections.
.. versionchanged:: 5.0
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
.. versionchanged:: 5.0
A callable is returned (``None`` was returned before).
"""
io_loop = IOLoop.current()
removed = [False]
def accept_handler(fd: socket.socket, events: int) -> None:
# More connections may come in while we're handling callbacks;
# to prevent starvation of other tasks we must limit the number
# of connections we accept at a time. Ideally we would accept
# up to the number of connections that were waiting when we
# entered this method, but this information is not available
# (and rearranging this method to call accept() as many times
# as possible before running any callbacks would have adverse
# effects on load balancing in multiprocess configurations).
# Instead, we use the (default) listen backlog as a rough
# heuristic for the number of connections we can reasonably
# accept at once.
for i in range(_DEFAULT_BACKLOG):
if removed[0]:
# The socket was probably closed
return
try:
connection, address = sock.accept()
except socket.error as e:
# _ERRNO_WOULDBLOCK indicate we have accepted every
# connection that is available.
if errno_from_exception(e) in _ERRNO_WOULDBLOCK:
return
# ECONNABORTED indicates that there was a connection
# but it was closed while still in the accept queue.
# (observed on FreeBSD).
if errno_from_exception(e) == errno.ECONNABORTED:
continue
raise
set_close_exec(connection.fileno())
callback(connection, address)
def remove_handler() -> None:
io_loop.remove_handler(sock)
removed[0] = True
io_loop.add_handler(sock, accept_handler, IOLoop.READ)
return remove_handler
def is_valid_ip(ip: str) -> bool:
"""Returns ``True`` if the given string is a well-formed IP address.
Supports IPv4 and IPv6.
"""
if not ip or "\x00" in ip:
# getaddrinfo resolves empty strings to localhost, and truncates
# on zero bytes.
return False
try:
res = socket.getaddrinfo(
ip, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_NUMERICHOST
)
return bool(res)
except socket.gaierror as e:
if e.args[0] == socket.EAI_NONAME:
return False
raise
return True
class Resolver(Configurable):
"""Configurable asynchronous DNS resolver interface.
By default, a blocking implementation is used (which simply calls
`socket.getaddrinfo`). An alternative implementation can be
chosen with the `Resolver.configure <.Configurable.configure>`
class method::
Resolver.configure('tornado.netutil.ThreadedResolver')
The implementations of this interface included with Tornado are
* `tornado.netutil.DefaultExecutorResolver`
* `tornado.netutil.BlockingResolver` (deprecated)
* `tornado.netutil.ThreadedResolver` (deprecated)
* `tornado.netutil.OverrideResolver`
* `tornado.platform.twisted.TwistedResolver`
* `tornado.platform.caresresolver.CaresResolver`
.. versionchanged:: 5.0
The default implementation has changed from `BlockingResolver` to
`DefaultExecutorResolver`.
"""
@classmethod
def configurable_base(cls) -> Type["Resolver"]:
return Resolver
@classmethod
def configurable_default(cls) -> Type["Resolver"]:
return DefaultExecutorResolver
def resolve(
self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC
) -> Awaitable[List[Tuple[int, Any]]]:
"""Resolves an address.
The ``host`` argument is a string which may be a hostname or a
literal IP address.
Returns a `.Future` whose result is a list of (family,
address) pairs, where address is a tuple suitable to pass to
`socket.connect <socket.socket.connect>` (i.e. a ``(host,
port)`` pair for IPv4; additional fields may be present for
IPv6). If a ``callback`` is passed, it will be run with the
result as an argument when it is complete.
:raises IOError: if the address cannot be resolved.
.. versionchanged:: 4.4
Standardized all implementations to raise `IOError`.
.. versionchanged:: 6.0 The ``callback`` argument was removed.
Use the returned awaitable object instead.
"""
raise NotImplementedError()
def close(self) -> None:
"""Closes the `Resolver`, freeing any resources used.
.. versionadded:: 3.1
"""
pass
def _resolve_addr(
host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC
) -> List[Tuple[int, Any]]:
# On Solaris, getaddrinfo fails if the given port is not found
# in /etc/services and no socket type is given, so we must pass
# one here. The socket type used here doesn't seem to actually
# matter (we discard the one we get back in the results),
# so the addresses we return should still be usable with SOCK_DGRAM.
addrinfo = socket.getaddrinfo(host, port, family, socket.SOCK_STREAM)
results = []
for fam, socktype, proto, canonname, address in addrinfo:
results.append((fam, address))
return results
class DefaultExecutorResolver(Resolver):
"""Resolver implementation using `.IOLoop.run_in_executor`.
.. versionadded:: 5.0
"""
async def resolve(
self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC
) -> List[Tuple[int, Any]]:
result = await IOLoop.current().run_in_executor(
None, _resolve_addr, host, port, family
)
return result
class ExecutorResolver(Resolver):
"""Resolver implementation using a `concurrent.futures.Executor`.
Use this instead of `ThreadedResolver` when you require additional
control over the executor being used.
The executor will be shut down when the resolver is closed unless
``close_resolver=False``; use this if you want to reuse the same
executor elsewhere.
.. versionchanged:: 5.0
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
.. deprecated:: 5.0
The default `Resolver` now uses `.IOLoop.run_in_executor`; use that instead
of this class.
"""
def initialize(
self, executor: concurrent.futures.Executor = None, close_executor: bool = True
) -> None:
self.io_loop = IOLoop.current()
if executor is not None:
self.executor = executor
self.close_executor = close_executor
else:
self.executor = dummy_executor
self.close_executor = False
def close(self) -> None:
if self.close_executor:
self.executor.shutdown()
self.executor = None # type: ignore
@run_on_executor
def resolve(
self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC
) -> List[Tuple[int, Any]]:
return _resolve_addr(host, port, family)
class BlockingResolver(ExecutorResolver):
"""Default `Resolver` implementation, using `socket.getaddrinfo`.
The `.IOLoop` will be blocked during the resolution, although the
callback will not be run until the next `.IOLoop` iteration.
.. deprecated:: 5.0
The default `Resolver` now uses `.IOLoop.run_in_executor`; use that instead
of this class.
"""
def initialize(self) -> None: # type: ignore
super(BlockingResolver, self).initialize()
class ThreadedResolver(ExecutorResolver):
"""Multithreaded non-blocking `Resolver` implementation.
Requires the `concurrent.futures` package to be installed
(available in the standard library since Python 3.2,
installable with ``pip install futures`` in older versions).
The thread pool size can be configured with::
Resolver.configure('tornado.netutil.ThreadedResolver',
num_threads=10)
.. versionchanged:: 3.1
All ``ThreadedResolvers`` share a single thread pool, whose
size is set by the first one to be created.
.. deprecated:: 5.0
The default `Resolver` now uses `.IOLoop.run_in_executor`; use that instead
of this class.
"""
_threadpool = None # type: ignore
_threadpool_pid = None # type: int
def initialize(self, num_threads: int = 10) -> None: # type: ignore
threadpool = ThreadedResolver._create_threadpool(num_threads)
super(ThreadedResolver, self).initialize(
executor=threadpool, close_executor=False
)
@classmethod
def _create_threadpool(
cls, num_threads: int
) -> concurrent.futures.ThreadPoolExecutor:
pid = os.getpid()
if cls._threadpool_pid != pid:
# Threads cannot survive after a fork, so if our pid isn't what it
# was when we created the pool then delete it.
cls._threadpool = None
if cls._threadpool is None:
cls._threadpool = concurrent.futures.ThreadPoolExecutor(num_threads)
cls._threadpool_pid = pid
return cls._threadpool
class OverrideResolver(Resolver):
"""Wraps a resolver with a mapping of overrides.
This can be used to make local DNS changes (e.g. for testing)
without modifying system-wide settings.
The mapping can be in three formats::
{
# Hostname to host or ip
"example.com": "127.0.1.1",
# Host+port to host+port
("login.example.com", 443): ("localhost", 1443),
# Host+port+address family to host+port
("login.example.com", 443, socket.AF_INET6): ("::1", 1443),
}
.. versionchanged:: 5.0
Added support for host-port-family triplets.
"""
def initialize(self, resolver: Resolver, mapping: dict) -> None:
self.resolver = resolver
self.mapping = mapping
def close(self) -> None:
self.resolver.close()
def resolve(
self, host: str, port: int, family: socket.AddressFamily = socket.AF_UNSPEC
) -> Awaitable[List[Tuple[int, Any]]]:
if (host, port, family) in self.mapping:
host, port = self.mapping[(host, port, family)]
elif (host, port) in self.mapping:
host, port = self.mapping[(host, port)]
elif host in self.mapping:
host = self.mapping[host]
return self.resolver.resolve(host, port, family)
# These are the keyword arguments to ssl.wrap_socket that must be translated
# to their SSLContext equivalents (the other arguments are still passed
# to SSLContext.wrap_socket).
_SSL_CONTEXT_KEYWORDS = frozenset(
["ssl_version", "certfile", "keyfile", "cert_reqs", "ca_certs", "ciphers"]
)
def ssl_options_to_context(
ssl_options: Union[Dict[str, Any], ssl.SSLContext]
) -> ssl.SSLContext:
"""Try to convert an ``ssl_options`` dictionary to an
`~ssl.SSLContext` object.
The ``ssl_options`` dictionary contains keywords to be passed to
`ssl.wrap_socket`. In Python 2.7.9+, `ssl.SSLContext` objects can
be used instead. This function converts the dict form to its
`~ssl.SSLContext` equivalent, and may be used when a component which
accepts both forms needs to upgrade to the `~ssl.SSLContext` version
to use features like SNI or NPN.
"""
if isinstance(ssl_options, ssl.SSLContext):
return ssl_options
assert isinstance(ssl_options, dict)
assert all(k in _SSL_CONTEXT_KEYWORDS for k in ssl_options), ssl_options
# Can't use create_default_context since this interface doesn't
# tell us client vs server.
context = ssl.SSLContext(ssl_options.get("ssl_version", ssl.PROTOCOL_SSLv23))
if "certfile" in ssl_options:
context.load_cert_chain(
ssl_options["certfile"], ssl_options.get("keyfile", None)
)
if "cert_reqs" in ssl_options:
context.verify_mode = ssl_options["cert_reqs"]
if "ca_certs" in ssl_options:
context.load_verify_locations(ssl_options["ca_certs"])
if "ciphers" in ssl_options:
context.set_ciphers(ssl_options["ciphers"])
if hasattr(ssl, "OP_NO_COMPRESSION"):
# Disable TLS compression to avoid CRIME and related attacks.
# This constant depends on openssl version 1.0.
# TODO: Do we need to do this ourselves or can we trust
# the defaults?
context.options |= ssl.OP_NO_COMPRESSION
return context
def ssl_wrap_socket(
socket: socket.socket,
ssl_options: Union[Dict[str, Any], ssl.SSLContext],
server_hostname: str = None,
**kwargs: Any
) -> ssl.SSLSocket:
"""Returns an ``ssl.SSLSocket`` wrapping the given socket.
``ssl_options`` may be either an `ssl.SSLContext` object or a
dictionary (as accepted by `ssl_options_to_context`). Additional
keyword arguments are passed to ``wrap_socket`` (either the
`~ssl.SSLContext` method or the `ssl` module function as
appropriate).
"""
context = ssl_options_to_context(ssl_options)
if ssl.HAS_SNI:
# In python 3.4, wrap_socket only accepts the server_hostname
# argument if HAS_SNI is true.
# TODO: add a unittest (python added server-side SNI support in 3.4)
# In the meantime it can be manually tested with
# python3 -m tornado.httpclient https://sni.velox.ch
return context.wrap_socket(socket, server_hostname=server_hostname, **kwargs)
else:
return context.wrap_socket(socket, **kwargs)

View File

@@ -0,0 +1,726 @@
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""A command line parsing module that lets modules define their own options.
This module is inspired by Google's `gflags
<https://github.com/google/python-gflags>`_. The primary difference
with libraries such as `argparse` is that a global registry is used so
that options may be defined in any module (it also enables
`tornado.log` by default). The rest of Tornado does not depend on this
module, so feel free to use `argparse` or other configuration
libraries if you prefer them.
Options must be defined with `tornado.options.define` before use,
generally at the top level of a module. The options are then
accessible as attributes of `tornado.options.options`::
# myapp/db.py
from tornado.options import define, options
define("mysql_host", default="127.0.0.1:3306", help="Main user DB")
define("memcache_hosts", default="127.0.0.1:11011", multiple=True,
help="Main user memcache servers")
def connect():
db = database.Connection(options.mysql_host)
...
# myapp/server.py
from tornado.options import define, options
define("port", default=8080, help="port to listen on")
def start_server():
app = make_app()
app.listen(options.port)
The ``main()`` method of your application does not need to be aware of all of
the options used throughout your program; they are all automatically loaded
when the modules are loaded. However, all modules that define options
must have been imported before the command line is parsed.
Your ``main()`` method can parse the command line or parse a config file with
either `parse_command_line` or `parse_config_file`::
import myapp.db, myapp.server
import tornado.options
if __name__ == '__main__':
tornado.options.parse_command_line()
# or
tornado.options.parse_config_file("/etc/server.conf")
.. note::
When using multiple ``parse_*`` functions, pass ``final=False`` to all
but the last one, or side effects may occur twice (in particular,
this can result in log messages being doubled).
`tornado.options.options` is a singleton instance of `OptionParser`, and
the top-level functions in this module (`define`, `parse_command_line`, etc)
simply call methods on it. You may create additional `OptionParser`
instances to define isolated sets of options, such as for subcommands.
.. note::
By default, several options are defined that will configure the
standard `logging` module when `parse_command_line` or `parse_config_file`
are called. If you want Tornado to leave the logging configuration
alone so you can manage it yourself, either pass ``--logging=none``
on the command line or do the following to disable it in code::
from tornado.options import options, parse_command_line
options.logging = None
parse_command_line()
.. versionchanged:: 4.3
Dashes and underscores are fully interchangeable in option names;
options can be defined, set, and read with any mix of the two.
Dashes are typical for command-line usage while config files require
underscores.
"""
import datetime
import numbers
import re
import sys
import os
import textwrap
from tornado.escape import _unicode, native_str
from tornado.log import define_logging_options
from tornado.util import basestring_type, exec_in
import typing
from typing import Any, Iterator, Iterable, Tuple, Set, Dict, Callable, List, TextIO
if typing.TYPE_CHECKING:
from typing import Optional # noqa: F401
class Error(Exception):
"""Exception raised by errors in the options module."""
pass
class OptionParser(object):
"""A collection of options, a dictionary with object-like access.
Normally accessed via static functions in the `tornado.options` module,
which reference a global instance.
"""
def __init__(self) -> None:
# we have to use self.__dict__ because we override setattr.
self.__dict__["_options"] = {}
self.__dict__["_parse_callbacks"] = []
self.define(
"help",
type=bool,
help="show this help information",
callback=self._help_callback,
)
def _normalize_name(self, name: str) -> str:
return name.replace("_", "-")
def __getattr__(self, name: str) -> Any:
name = self._normalize_name(name)
if isinstance(self._options.get(name), _Option):
return self._options[name].value()
raise AttributeError("Unrecognized option %r" % name)
def __setattr__(self, name: str, value: Any) -> None:
name = self._normalize_name(name)
if isinstance(self._options.get(name), _Option):
return self._options[name].set(value)
raise AttributeError("Unrecognized option %r" % name)
def __iter__(self) -> Iterator:
return (opt.name for opt in self._options.values())
def __contains__(self, name: str) -> bool:
name = self._normalize_name(name)
return name in self._options
def __getitem__(self, name: str) -> Any:
return self.__getattr__(name)
def __setitem__(self, name: str, value: Any) -> None:
return self.__setattr__(name, value)
def items(self) -> Iterable[Tuple[str, Any]]:
"""An iterable of (name, value) pairs.
.. versionadded:: 3.1
"""
return [(opt.name, opt.value()) for name, opt in self._options.items()]
def groups(self) -> Set[str]:
"""The set of option-groups created by ``define``.
.. versionadded:: 3.1
"""
return set(opt.group_name for opt in self._options.values())
def group_dict(self, group: str) -> Dict[str, Any]:
"""The names and values of options in a group.
Useful for copying options into Application settings::
from tornado.options import define, parse_command_line, options
define('template_path', group='application')
define('static_path', group='application')
parse_command_line()
application = Application(
handlers, **options.group_dict('application'))
.. versionadded:: 3.1
"""
return dict(
(opt.name, opt.value())
for name, opt in self._options.items()
if not group or group == opt.group_name
)
def as_dict(self) -> Dict[str, Any]:
"""The names and values of all options.
.. versionadded:: 3.1
"""
return dict((opt.name, opt.value()) for name, opt in self._options.items())
def define(
self,
name: str,
default: Any = None,
type: type = None,
help: str = None,
metavar: str = None,
multiple: bool = False,
group: str = None,
callback: Callable[[Any], None] = None,
) -> None:
"""Defines a new command line option.
``type`` can be any of `str`, `int`, `float`, `bool`,
`~datetime.datetime`, or `~datetime.timedelta`. If no ``type``
is given but a ``default`` is, ``type`` is the type of
``default``. Otherwise, ``type`` defaults to `str`.
If ``multiple`` is True, the option value is a list of ``type``
instead of an instance of ``type``.
``help`` and ``metavar`` are used to construct the
automatically generated command line help string. The help
message is formatted like::
--name=METAVAR help string
``group`` is used to group the defined options in logical
groups. By default, command line options are grouped by the
file in which they are defined.
Command line option names must be unique globally.
If a ``callback`` is given, it will be run with the new value whenever
the option is changed. This can be used to combine command-line
and file-based options::
define("config", type=str, help="path to config file",
callback=lambda path: parse_config_file(path, final=False))
With this definition, options in the file specified by ``--config`` will
override options set earlier on the command line, but can be overridden
by later flags.
"""
normalized = self._normalize_name(name)
if normalized in self._options:
raise Error(
"Option %r already defined in %s"
% (normalized, self._options[normalized].file_name)
)
frame = sys._getframe(0)
options_file = frame.f_code.co_filename
# Can be called directly, or through top level define() fn, in which
# case, step up above that frame to look for real caller.
if (
frame.f_back.f_code.co_filename == options_file
and frame.f_back.f_code.co_name == "define"
):
frame = frame.f_back
file_name = frame.f_back.f_code.co_filename
if file_name == options_file:
file_name = ""
if type is None:
if not multiple and default is not None:
type = default.__class__
else:
type = str
if group:
group_name = group # type: Optional[str]
else:
group_name = file_name
option = _Option(
name,
file_name=file_name,
default=default,
type=type,
help=help,
metavar=metavar,
multiple=multiple,
group_name=group_name,
callback=callback,
)
self._options[normalized] = option
def parse_command_line(
self, args: List[str] = None, final: bool = True
) -> List[str]:
"""Parses all options given on the command line (defaults to
`sys.argv`).
Options look like ``--option=value`` and are parsed according
to their ``type``. For boolean options, ``--option`` is
equivalent to ``--option=true``
If the option has ``multiple=True``, comma-separated values
are accepted. For multi-value integer options, the syntax
``x:y`` is also accepted and equivalent to ``range(x, y)``.
Note that ``args[0]`` is ignored since it is the program name
in `sys.argv`.
We return a list of all arguments that are not parsed as options.
If ``final`` is ``False``, parse callbacks will not be run.
This is useful for applications that wish to combine configurations
from multiple sources.
"""
if args is None:
args = sys.argv
remaining = [] # type: List[str]
for i in range(1, len(args)):
# All things after the last option are command line arguments
if not args[i].startswith("-"):
remaining = args[i:]
break
if args[i] == "--":
remaining = args[i + 1 :]
break
arg = args[i].lstrip("-")
name, equals, value = arg.partition("=")
name = self._normalize_name(name)
if name not in self._options:
self.print_help()
raise Error("Unrecognized command line option: %r" % name)
option = self._options[name]
if not equals:
if option.type == bool:
value = "true"
else:
raise Error("Option %r requires a value" % name)
option.parse(value)
if final:
self.run_parse_callbacks()
return remaining
def parse_config_file(self, path: str, final: bool = True) -> None:
"""Parses and loads the config file at the given path.
The config file contains Python code that will be executed (so
it is **not safe** to use untrusted config files). Anything in
the global namespace that matches a defined option will be
used to set that option's value.
Options may either be the specified type for the option or
strings (in which case they will be parsed the same way as in
`.parse_command_line`)
Example (using the options defined in the top-level docs of
this module)::
port = 80
mysql_host = 'mydb.example.com:3306'
# Both lists and comma-separated strings are allowed for
# multiple=True.
memcache_hosts = ['cache1.example.com:11011',
'cache2.example.com:11011']
memcache_hosts = 'cache1.example.com:11011,cache2.example.com:11011'
If ``final`` is ``False``, parse callbacks will not be run.
This is useful for applications that wish to combine configurations
from multiple sources.
.. note::
`tornado.options` is primarily a command-line library.
Config file support is provided for applications that wish
to use it, but applications that prefer config files may
wish to look at other libraries instead.
.. versionchanged:: 4.1
Config files are now always interpreted as utf-8 instead of
the system default encoding.
.. versionchanged:: 4.4
The special variable ``__file__`` is available inside config
files, specifying the absolute path to the config file itself.
.. versionchanged:: 5.1
Added the ability to set options via strings in config files.
"""
config = {"__file__": os.path.abspath(path)}
with open(path, "rb") as f:
exec_in(native_str(f.read()), config, config)
for name in config:
normalized = self._normalize_name(name)
if normalized in self._options:
option = self._options[normalized]
if option.multiple:
if not isinstance(config[name], (list, str)):
raise Error(
"Option %r is required to be a list of %s "
"or a comma-separated string"
% (option.name, option.type.__name__)
)
if type(config[name]) == str and option.type != str:
option.parse(config[name])
else:
option.set(config[name])
if final:
self.run_parse_callbacks()
def print_help(self, file: TextIO = None) -> None:
"""Prints all the command line options to stderr (or another file)."""
if file is None:
file = sys.stderr
print("Usage: %s [OPTIONS]" % sys.argv[0], file=file)
print("\nOptions:\n", file=file)
by_group = {} # type: Dict[str, List[_Option]]
for option in self._options.values():
by_group.setdefault(option.group_name, []).append(option)
for filename, o in sorted(by_group.items()):
if filename:
print("\n%s options:\n" % os.path.normpath(filename), file=file)
o.sort(key=lambda option: option.name)
for option in o:
# Always print names with dashes in a CLI context.
prefix = self._normalize_name(option.name)
if option.metavar:
prefix += "=" + option.metavar
description = option.help or ""
if option.default is not None and option.default != "":
description += " (default %s)" % option.default
lines = textwrap.wrap(description, 79 - 35)
if len(prefix) > 30 or len(lines) == 0:
lines.insert(0, "")
print(" --%-30s %s" % (prefix, lines[0]), file=file)
for line in lines[1:]:
print("%-34s %s" % (" ", line), file=file)
print(file=file)
def _help_callback(self, value: bool) -> None:
if value:
self.print_help()
sys.exit(0)
def add_parse_callback(self, callback: Callable[[], None]) -> None:
"""Adds a parse callback, to be invoked when option parsing is done."""
self._parse_callbacks.append(callback)
def run_parse_callbacks(self) -> None:
for callback in self._parse_callbacks:
callback()
def mockable(self) -> "_Mockable":
"""Returns a wrapper around self that is compatible with
`mock.patch <unittest.mock.patch>`.
The `mock.patch <unittest.mock.patch>` function (included in
the standard library `unittest.mock` package since Python 3.3,
or in the third-party ``mock`` package for older versions of
Python) is incompatible with objects like ``options`` that
override ``__getattr__`` and ``__setattr__``. This function
returns an object that can be used with `mock.patch.object
<unittest.mock.patch.object>` to modify option values::
with mock.patch.object(options.mockable(), 'name', value):
assert options.name == value
"""
return _Mockable(self)
class _Mockable(object):
"""`mock.patch` compatible wrapper for `OptionParser`.
As of ``mock`` version 1.0.1, when an object uses ``__getattr__``
hooks instead of ``__dict__``, ``patch.__exit__`` tries to delete
the attribute it set instead of setting a new one (assuming that
the object does not catpure ``__setattr__``, so the patch
created a new attribute in ``__dict__``).
_Mockable's getattr and setattr pass through to the underlying
OptionParser, and delattr undoes the effect of a previous setattr.
"""
def __init__(self, options: OptionParser) -> None:
# Modify __dict__ directly to bypass __setattr__
self.__dict__["_options"] = options
self.__dict__["_originals"] = {}
def __getattr__(self, name: str) -> Any:
return getattr(self._options, name)
def __setattr__(self, name: str, value: Any) -> None:
assert name not in self._originals, "don't reuse mockable objects"
self._originals[name] = getattr(self._options, name)
setattr(self._options, name, value)
def __delattr__(self, name: str) -> None:
setattr(self._options, name, self._originals.pop(name))
class _Option(object):
# This class could almost be made generic, but the way the types
# interact with the multiple argument makes this tricky. (default
# and the callback use List[T], but type is still Type[T]).
UNSET = object()
def __init__(
self,
name: str,
default: Any = None,
type: type = None,
help: str = None,
metavar: str = None,
multiple: bool = False,
file_name: str = None,
group_name: str = None,
callback: Callable[[Any], None] = None,
) -> None:
if default is None and multiple:
default = []
self.name = name
if type is None:
raise ValueError("type must not be None")
self.type = type
self.help = help
self.metavar = metavar
self.multiple = multiple
self.file_name = file_name
self.group_name = group_name
self.callback = callback
self.default = default
self._value = _Option.UNSET # type: Any
def value(self) -> Any:
return self.default if self._value is _Option.UNSET else self._value
def parse(self, value: str) -> Any:
_parse = {
datetime.datetime: self._parse_datetime,
datetime.timedelta: self._parse_timedelta,
bool: self._parse_bool,
basestring_type: self._parse_string,
}.get(
self.type, self.type
) # type: Callable[[str], Any]
if self.multiple:
self._value = []
for part in value.split(","):
if issubclass(self.type, numbers.Integral):
# allow ranges of the form X:Y (inclusive at both ends)
lo_str, _, hi_str = part.partition(":")
lo = _parse(lo_str)
hi = _parse(hi_str) if hi_str else lo
self._value.extend(range(lo, hi + 1))
else:
self._value.append(_parse(part))
else:
self._value = _parse(value)
if self.callback is not None:
self.callback(self._value)
return self.value()
def set(self, value: Any) -> None:
if self.multiple:
if not isinstance(value, list):
raise Error(
"Option %r is required to be a list of %s"
% (self.name, self.type.__name__)
)
for item in value:
if item is not None and not isinstance(item, self.type):
raise Error(
"Option %r is required to be a list of %s"
% (self.name, self.type.__name__)
)
else:
if value is not None and not isinstance(value, self.type):
raise Error(
"Option %r is required to be a %s (%s given)"
% (self.name, self.type.__name__, type(value))
)
self._value = value
if self.callback is not None:
self.callback(self._value)
# Supported date/time formats in our options
_DATETIME_FORMATS = [
"%a %b %d %H:%M:%S %Y",
"%Y-%m-%d %H:%M:%S",
"%Y-%m-%d %H:%M",
"%Y-%m-%dT%H:%M",
"%Y%m%d %H:%M:%S",
"%Y%m%d %H:%M",
"%Y-%m-%d",
"%Y%m%d",
"%H:%M:%S",
"%H:%M",
]
def _parse_datetime(self, value: str) -> datetime.datetime:
for format in self._DATETIME_FORMATS:
try:
return datetime.datetime.strptime(value, format)
except ValueError:
pass
raise Error("Unrecognized date/time format: %r" % value)
_TIMEDELTA_ABBREV_DICT = {
"h": "hours",
"m": "minutes",
"min": "minutes",
"s": "seconds",
"sec": "seconds",
"ms": "milliseconds",
"us": "microseconds",
"d": "days",
"w": "weeks",
}
_FLOAT_PATTERN = r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?"
_TIMEDELTA_PATTERN = re.compile(
r"\s*(%s)\s*(\w*)\s*" % _FLOAT_PATTERN, re.IGNORECASE
)
def _parse_timedelta(self, value: str) -> datetime.timedelta:
try:
sum = datetime.timedelta()
start = 0
while start < len(value):
m = self._TIMEDELTA_PATTERN.match(value, start)
if not m:
raise Exception()
num = float(m.group(1))
units = m.group(2) or "seconds"
units = self._TIMEDELTA_ABBREV_DICT.get(units, units)
sum += datetime.timedelta(**{units: num})
start = m.end()
return sum
except Exception:
raise
def _parse_bool(self, value: str) -> bool:
return value.lower() not in ("false", "0", "f")
def _parse_string(self, value: str) -> str:
return _unicode(value)
options = OptionParser()
"""Global options object.
All defined options are available as attributes on this object.
"""
def define(
name: str,
default: Any = None,
type: type = None,
help: str = None,
metavar: str = None,
multiple: bool = False,
group: str = None,
callback: Callable[[Any], None] = None,
) -> None:
"""Defines an option in the global namespace.
See `OptionParser.define`.
"""
return options.define(
name,
default=default,
type=type,
help=help,
metavar=metavar,
multiple=multiple,
group=group,
callback=callback,
)
def parse_command_line(args: List[str] = None, final: bool = True) -> List[str]:
"""Parses global options from the command line.
See `OptionParser.parse_command_line`.
"""
return options.parse_command_line(args, final=final)
def parse_config_file(path: str, final: bool = True) -> None:
"""Parses global options from a config file.
See `OptionParser.parse_config_file`.
"""
return options.parse_config_file(path, final=final)
def print_help(file: TextIO = None) -> None:
"""Prints all the command line options to stderr (or another file).
See `OptionParser.print_help`.
"""
return options.print_help(file)
def add_parse_callback(callback: Callable[[], None]) -> None:
"""Adds a parse callback, to be invoked when option parsing is done.
See `OptionParser.add_parse_callback`
"""
options.add_parse_callback(callback)
# Default options
define_logging_options(options)

View File

@@ -0,0 +1,337 @@
"""Bridges between the `asyncio` module and Tornado IOLoop.
.. versionadded:: 3.2
This module integrates Tornado with the ``asyncio`` module introduced
in Python 3.4. This makes it possible to combine the two libraries on
the same event loop.
.. deprecated:: 5.0
While the code in this module is still used, it is now enabled
automatically when `asyncio` is available, so applications should
no longer need to refer to this module directly.
.. note::
Tornado requires the `~asyncio.AbstractEventLoop.add_reader` family of
methods, so it is not compatible with the `~asyncio.ProactorEventLoop` on
Windows. Use the `~asyncio.SelectorEventLoop` instead.
"""
import concurrent.futures
import functools
from threading import get_ident
from tornado.gen import convert_yielded
from tornado.ioloop import IOLoop, _Selectable
import asyncio
import typing
from typing import Any, TypeVar, Awaitable, Callable, Union, Optional
if typing.TYPE_CHECKING:
from typing import Set, Dict, Tuple # noqa: F401
_T = TypeVar("_T")
class BaseAsyncIOLoop(IOLoop):
def initialize( # type: ignore
self, asyncio_loop: asyncio.AbstractEventLoop, **kwargs: Any
) -> None:
self.asyncio_loop = asyncio_loop
# Maps fd to (fileobj, handler function) pair (as in IOLoop.add_handler)
self.handlers = {} # type: Dict[int, Tuple[Union[int, _Selectable], Callable]]
# Set of fds listening for reads/writes
self.readers = set() # type: Set[int]
self.writers = set() # type: Set[int]
self.closing = False
# If an asyncio loop was closed through an asyncio interface
# instead of IOLoop.close(), we'd never hear about it and may
# have left a dangling reference in our map. In case an
# application (or, more likely, a test suite) creates and
# destroys a lot of event loops in this way, check here to
# ensure that we don't have a lot of dead loops building up in
# the map.
#
# TODO(bdarnell): consider making self.asyncio_loop a weakref
# for AsyncIOMainLoop and make _ioloop_for_asyncio a
# WeakKeyDictionary.
for loop in list(IOLoop._ioloop_for_asyncio):
if loop.is_closed():
del IOLoop._ioloop_for_asyncio[loop]
IOLoop._ioloop_for_asyncio[asyncio_loop] = self
self._thread_identity = 0
super(BaseAsyncIOLoop, self).initialize(**kwargs)
def assign_thread_identity() -> None:
self._thread_identity = get_ident()
self.add_callback(assign_thread_identity)
def close(self, all_fds: bool = False) -> None:
self.closing = True
for fd in list(self.handlers):
fileobj, handler_func = self.handlers[fd]
self.remove_handler(fd)
if all_fds:
self.close_fd(fileobj)
# Remove the mapping before closing the asyncio loop. If this
# happened in the other order, we could race against another
# initialize() call which would see the closed asyncio loop,
# assume it was closed from the asyncio side, and do this
# cleanup for us, leading to a KeyError.
del IOLoop._ioloop_for_asyncio[self.asyncio_loop]
self.asyncio_loop.close()
def add_handler(
self, fd: Union[int, _Selectable], handler: Callable[..., None], events: int
) -> None:
fd, fileobj = self.split_fd(fd)
if fd in self.handlers:
raise ValueError("fd %s added twice" % fd)
self.handlers[fd] = (fileobj, handler)
if events & IOLoop.READ:
self.asyncio_loop.add_reader(fd, self._handle_events, fd, IOLoop.READ)
self.readers.add(fd)
if events & IOLoop.WRITE:
self.asyncio_loop.add_writer(fd, self._handle_events, fd, IOLoop.WRITE)
self.writers.add(fd)
def update_handler(self, fd: Union[int, _Selectable], events: int) -> None:
fd, fileobj = self.split_fd(fd)
if events & IOLoop.READ:
if fd not in self.readers:
self.asyncio_loop.add_reader(fd, self._handle_events, fd, IOLoop.READ)
self.readers.add(fd)
else:
if fd in self.readers:
self.asyncio_loop.remove_reader(fd)
self.readers.remove(fd)
if events & IOLoop.WRITE:
if fd not in self.writers:
self.asyncio_loop.add_writer(fd, self._handle_events, fd, IOLoop.WRITE)
self.writers.add(fd)
else:
if fd in self.writers:
self.asyncio_loop.remove_writer(fd)
self.writers.remove(fd)
def remove_handler(self, fd: Union[int, _Selectable]) -> None:
fd, fileobj = self.split_fd(fd)
if fd not in self.handlers:
return
if fd in self.readers:
self.asyncio_loop.remove_reader(fd)
self.readers.remove(fd)
if fd in self.writers:
self.asyncio_loop.remove_writer(fd)
self.writers.remove(fd)
del self.handlers[fd]
def _handle_events(self, fd: int, events: int) -> None:
fileobj, handler_func = self.handlers[fd]
handler_func(fileobj, events)
def start(self) -> None:
try:
old_loop = asyncio.get_event_loop()
except (RuntimeError, AssertionError):
old_loop = None # type: ignore
try:
self._setup_logging()
asyncio.set_event_loop(self.asyncio_loop)
self.asyncio_loop.run_forever()
finally:
asyncio.set_event_loop(old_loop)
def stop(self) -> None:
self.asyncio_loop.stop()
def call_at(
self, when: float, callback: Callable[..., None], *args: Any, **kwargs: Any
) -> object:
# asyncio.call_at supports *args but not **kwargs, so bind them here.
# We do not synchronize self.time and asyncio_loop.time, so
# convert from absolute to relative.
return self.asyncio_loop.call_later(
max(0, when - self.time()),
self._run_callback,
functools.partial(callback, *args, **kwargs),
)
def remove_timeout(self, timeout: object) -> None:
timeout.cancel() # type: ignore
def add_callback(self, callback: Callable, *args: Any, **kwargs: Any) -> None:
if get_ident() == self._thread_identity:
call_soon = self.asyncio_loop.call_soon
else:
call_soon = self.asyncio_loop.call_soon_threadsafe
try:
call_soon(self._run_callback, functools.partial(callback, *args, **kwargs))
except RuntimeError:
# "Event loop is closed". Swallow the exception for
# consistency with PollIOLoop (and logical consistency
# with the fact that we can't guarantee that an
# add_callback that completes without error will
# eventually execute).
pass
def add_callback_from_signal(
self, callback: Callable, *args: Any, **kwargs: Any
) -> None:
try:
self.asyncio_loop.call_soon_threadsafe(
self._run_callback, functools.partial(callback, *args, **kwargs)
)
except RuntimeError:
pass
def run_in_executor(
self,
executor: Optional[concurrent.futures.Executor],
func: Callable[..., _T],
*args: Any
) -> Awaitable[_T]:
return self.asyncio_loop.run_in_executor(executor, func, *args)
def set_default_executor(self, executor: concurrent.futures.Executor) -> None:
return self.asyncio_loop.set_default_executor(executor)
class AsyncIOMainLoop(BaseAsyncIOLoop):
"""``AsyncIOMainLoop`` creates an `.IOLoop` that corresponds to the
current ``asyncio`` event loop (i.e. the one returned by
``asyncio.get_event_loop()``).
.. deprecated:: 5.0
Now used automatically when appropriate; it is no longer necessary
to refer to this class directly.
.. versionchanged:: 5.0
Closing an `AsyncIOMainLoop` now closes the underlying asyncio loop.
"""
def initialize(self, **kwargs: Any) -> None: # type: ignore
super(AsyncIOMainLoop, self).initialize(asyncio.get_event_loop(), **kwargs)
def make_current(self) -> None:
# AsyncIOMainLoop already refers to the current asyncio loop so
# nothing to do here.
pass
class AsyncIOLoop(BaseAsyncIOLoop):
"""``AsyncIOLoop`` is an `.IOLoop` that runs on an ``asyncio`` event loop.
This class follows the usual Tornado semantics for creating new
``IOLoops``; these loops are not necessarily related to the
``asyncio`` default event loop.
Each ``AsyncIOLoop`` creates a new ``asyncio.EventLoop``; this object
can be accessed with the ``asyncio_loop`` attribute.
.. versionchanged:: 5.0
When an ``AsyncIOLoop`` becomes the current `.IOLoop`, it also sets
the current `asyncio` event loop.
.. deprecated:: 5.0
Now used automatically when appropriate; it is no longer necessary
to refer to this class directly.
"""
def initialize(self, **kwargs: Any) -> None: # type: ignore
self.is_current = False
loop = asyncio.new_event_loop()
try:
super(AsyncIOLoop, self).initialize(loop, **kwargs)
except Exception:
# If initialize() does not succeed (taking ownership of the loop),
# we have to close it.
loop.close()
raise
def close(self, all_fds: bool = False) -> None:
if self.is_current:
self.clear_current()
super(AsyncIOLoop, self).close(all_fds=all_fds)
def make_current(self) -> None:
if not self.is_current:
try:
self.old_asyncio = asyncio.get_event_loop()
except (RuntimeError, AssertionError):
self.old_asyncio = None # type: ignore
self.is_current = True
asyncio.set_event_loop(self.asyncio_loop)
def _clear_current_hook(self) -> None:
if self.is_current:
asyncio.set_event_loop(self.old_asyncio)
self.is_current = False
def to_tornado_future(asyncio_future: asyncio.Future) -> asyncio.Future:
"""Convert an `asyncio.Future` to a `tornado.concurrent.Future`.
.. versionadded:: 4.1
.. deprecated:: 5.0
Tornado ``Futures`` have been merged with `asyncio.Future`,
so this method is now a no-op.
"""
return asyncio_future
def to_asyncio_future(tornado_future: asyncio.Future) -> asyncio.Future:
"""Convert a Tornado yieldable object to an `asyncio.Future`.
.. versionadded:: 4.1
.. versionchanged:: 4.3
Now accepts any yieldable object, not just
`tornado.concurrent.Future`.
.. deprecated:: 5.0
Tornado ``Futures`` have been merged with `asyncio.Future`,
so this method is now equivalent to `tornado.gen.convert_yielded`.
"""
return convert_yielded(tornado_future)
class AnyThreadEventLoopPolicy(asyncio.DefaultEventLoopPolicy): # type: ignore
"""Event loop policy that allows loop creation on any thread.
The default `asyncio` event loop policy only automatically creates
event loops in the main threads. Other threads must create event
loops explicitly or `asyncio.get_event_loop` (and therefore
`.IOLoop.current`) will fail. Installing this policy allows event
loops to be created automatically on any thread, matching the
behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).
Usage::
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
.. versionadded:: 5.0
"""
def get_event_loop(self) -> asyncio.AbstractEventLoop:
try:
return super().get_event_loop()
except (RuntimeError, AssertionError):
# This was an AssertionError in python 3.4.2 (which ships with debian jessie)
# and changed to a RuntimeError in 3.4.3.
# "There is no current event loop in thread %r"
loop = self.new_event_loop()
self.set_event_loop(loop)
return loop

View File

@@ -0,0 +1,32 @@
#
# Copyright 2011 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Implementation of platform-specific functionality.
For each function or class described in `tornado.platform.interface`,
the appropriate platform-specific implementation exists in this module.
Most code that needs access to this functionality should do e.g.::
from tornado.platform.auto import set_close_exec
"""
import os
if os.name == "nt":
from tornado.platform.windows import set_close_exec
else:
from tornado.platform.posix import set_close_exec
__all__ = ["set_close_exec"]

View File

@@ -0,0 +1,89 @@
import pycares # type: ignore
import socket
from tornado.concurrent import Future
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.netutil import Resolver, is_valid_ip
import typing
if typing.TYPE_CHECKING:
from typing import Generator, Any, List, Tuple, Dict # noqa: F401
class CaresResolver(Resolver):
"""Name resolver based on the c-ares library.
This is a non-blocking and non-threaded resolver. It may not produce
the same results as the system resolver, but can be used for non-blocking
resolution when threads cannot be used.
c-ares fails to resolve some names when ``family`` is ``AF_UNSPEC``,
so it is only recommended for use in ``AF_INET`` (i.e. IPv4). This is
the default for ``tornado.simple_httpclient``, but other libraries
may default to ``AF_UNSPEC``.
.. versionchanged:: 5.0
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
"""
def initialize(self) -> None:
self.io_loop = IOLoop.current()
self.channel = pycares.Channel(sock_state_cb=self._sock_state_cb)
self.fds = {} # type: Dict[int, int]
def _sock_state_cb(self, fd: int, readable: bool, writable: bool) -> None:
state = (IOLoop.READ if readable else 0) | (IOLoop.WRITE if writable else 0)
if not state:
self.io_loop.remove_handler(fd)
del self.fds[fd]
elif fd in self.fds:
self.io_loop.update_handler(fd, state)
self.fds[fd] = state
else:
self.io_loop.add_handler(fd, self._handle_events, state)
self.fds[fd] = state
def _handle_events(self, fd: int, events: int) -> None:
read_fd = pycares.ARES_SOCKET_BAD
write_fd = pycares.ARES_SOCKET_BAD
if events & IOLoop.READ:
read_fd = fd
if events & IOLoop.WRITE:
write_fd = fd
self.channel.process_fd(read_fd, write_fd)
@gen.coroutine
def resolve(
self, host: str, port: int, family: int = 0
) -> "Generator[Any, Any, List[Tuple[int, Any]]]":
if is_valid_ip(host):
addresses = [host]
else:
# gethostbyname doesn't take callback as a kwarg
fut = Future() # type: Future[Tuple[Any, Any]]
self.channel.gethostbyname(
host, family, lambda result, error: fut.set_result((result, error))
)
result, error = yield fut
if error:
raise IOError(
"C-Ares returned error %s: %s while resolving %s"
% (error, pycares.errno.strerror(error), host)
)
addresses = result.addresses
addrinfo = []
for address in addresses:
if "." in address:
address_family = socket.AF_INET
elif ":" in address:
address_family = socket.AF_INET6
else:
address_family = socket.AF_UNSPEC
if family != socket.AF_UNSPEC and family != address_family:
raise IOError(
"Requested socket family %d but got %d" % (family, address_family)
)
addrinfo.append((typing.cast(int, address_family), (address, port)))
return addrinfo

View File

@@ -0,0 +1,26 @@
#
# Copyright 2011 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Interfaces for platform-specific functionality.
This module exists primarily for documentation purposes and as base classes
for other tornado.platform modules. Most code should import the appropriate
implementation from `tornado.platform.auto`.
"""
def set_close_exec(fd: int) -> None:
"""Sets the close-on-exec bit (``FD_CLOEXEC``)for a file descriptor."""
raise NotImplementedError()

View File

@@ -0,0 +1,29 @@
#
# Copyright 2011 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Posix implementations of platform-specific functionality."""
import fcntl
import os
def set_close_exec(fd: int) -> None:
flags = fcntl.fcntl(fd, fcntl.F_GETFD)
fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC)
def _set_nonblocking(fd: int) -> None:
flags = fcntl.fcntl(fd, fcntl.F_GETFL)
fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)

View File

@@ -0,0 +1,131 @@
# Author: Ovidiu Predescu
# Date: July 2011
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Bridges between the Twisted reactor and Tornado IOLoop.
This module lets you run applications and libraries written for
Twisted in a Tornado application. It can be used in two modes,
depending on which library's underlying event loop you want to use.
This module has been tested with Twisted versions 11.0.0 and newer.
"""
import socket
import sys
import twisted.internet.abstract # type: ignore
import twisted.internet.asyncioreactor # type: ignore
from twisted.internet.defer import Deferred # type: ignore
from twisted.python import failure # type: ignore
import twisted.names.cache # type: ignore
import twisted.names.client # type: ignore
import twisted.names.hosts # type: ignore
import twisted.names.resolve # type: ignore
from tornado.concurrent import Future, future_set_exc_info
from tornado.escape import utf8
from tornado import gen
from tornado.netutil import Resolver
import typing
if typing.TYPE_CHECKING:
from typing import Generator, Any, List, Tuple # noqa: F401
class TwistedResolver(Resolver):
"""Twisted-based asynchronous resolver.
This is a non-blocking and non-threaded resolver. It is
recommended only when threads cannot be used, since it has
limitations compared to the standard ``getaddrinfo``-based
`~tornado.netutil.Resolver` and
`~tornado.netutil.DefaultExecutorResolver`. Specifically, it returns at
most one result, and arguments other than ``host`` and ``family``
are ignored. It may fail to resolve when ``family`` is not
``socket.AF_UNSPEC``.
Requires Twisted 12.1 or newer.
.. versionchanged:: 5.0
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
"""
def initialize(self) -> None:
# partial copy of twisted.names.client.createResolver, which doesn't
# allow for a reactor to be passed in.
self.reactor = twisted.internet.asyncioreactor.AsyncioSelectorReactor()
host_resolver = twisted.names.hosts.Resolver("/etc/hosts")
cache_resolver = twisted.names.cache.CacheResolver(reactor=self.reactor)
real_resolver = twisted.names.client.Resolver(
"/etc/resolv.conf", reactor=self.reactor
)
self.resolver = twisted.names.resolve.ResolverChain(
[host_resolver, cache_resolver, real_resolver]
)
@gen.coroutine
def resolve(
self, host: str, port: int, family: int = 0
) -> "Generator[Any, Any, List[Tuple[int, Any]]]":
# getHostByName doesn't accept IP addresses, so if the input
# looks like an IP address just return it immediately.
if twisted.internet.abstract.isIPAddress(host):
resolved = host
resolved_family = socket.AF_INET
elif twisted.internet.abstract.isIPv6Address(host):
resolved = host
resolved_family = socket.AF_INET6
else:
deferred = self.resolver.getHostByName(utf8(host))
fut = Future() # type: Future[Any]
deferred.addBoth(fut.set_result)
resolved = yield fut
if isinstance(resolved, failure.Failure):
try:
resolved.raiseException()
except twisted.names.error.DomainError as e:
raise IOError(e)
elif twisted.internet.abstract.isIPAddress(resolved):
resolved_family = socket.AF_INET
elif twisted.internet.abstract.isIPv6Address(resolved):
resolved_family = socket.AF_INET6
else:
resolved_family = socket.AF_UNSPEC
if family != socket.AF_UNSPEC and family != resolved_family:
raise Exception(
"Requested socket family %d but got %d" % (family, resolved_family)
)
result = [(typing.cast(int, resolved_family), (resolved, port))]
return result
if hasattr(gen.convert_yielded, "register"):
@gen.convert_yielded.register(Deferred) # type: ignore
def _(d: Deferred) -> Future:
f = Future() # type: Future[Any]
def errback(failure: failure.Failure) -> None:
try:
failure.raiseException()
# Should never happen, but just in case
raise Exception("errback called without error")
except:
future_set_exc_info(f, sys.exc_info())
d.addCallbacks(f.set_result, errback)
return f

View File

@@ -0,0 +1,22 @@
# NOTE: win32 support is currently experimental, and not recommended
# for production use.
import ctypes
import ctypes.wintypes
# See: http://msdn.microsoft.com/en-us/library/ms724935(VS.85).aspx
SetHandleInformation = ctypes.windll.kernel32.SetHandleInformation # type: ignore
SetHandleInformation.argtypes = (
ctypes.wintypes.HANDLE,
ctypes.wintypes.DWORD,
ctypes.wintypes.DWORD,
)
SetHandleInformation.restype = ctypes.wintypes.BOOL
HANDLE_FLAG_INHERIT = 0x00000001
def set_close_exec(fd: int) -> None:
success = SetHandleInformation(fd, HANDLE_FLAG_INHERIT, 0)
if not success:
raise ctypes.WinError() # type: ignore

View File

@@ -0,0 +1,373 @@
#
# Copyright 2011 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Utilities for working with multiple processes, including both forking
the server into multiple processes and managing subprocesses.
"""
import errno
import os
import multiprocessing
import signal
import subprocess
import sys
import time
from binascii import hexlify
from tornado.concurrent import (
Future,
future_set_result_unless_cancelled,
future_set_exception_unless_cancelled,
)
from tornado import ioloop
from tornado.iostream import PipeIOStream
from tornado.log import gen_log
from tornado.platform.auto import set_close_exec
from tornado.util import errno_from_exception
import typing
from typing import Tuple, Optional, Any, Callable
if typing.TYPE_CHECKING:
from typing import List # noqa: F401
# Re-export this exception for convenience.
CalledProcessError = subprocess.CalledProcessError
def cpu_count() -> int:
"""Returns the number of processors on this machine."""
if multiprocessing is None:
return 1
try:
return multiprocessing.cpu_count()
except NotImplementedError:
pass
try:
return os.sysconf("SC_NPROCESSORS_CONF")
except (AttributeError, ValueError):
pass
gen_log.error("Could not detect number of processors; assuming 1")
return 1
def _reseed_random() -> None:
if "random" not in sys.modules:
return
import random
# If os.urandom is available, this method does the same thing as
# random.seed (at least as of python 2.6). If os.urandom is not
# available, we mix in the pid in addition to a timestamp.
try:
seed = int(hexlify(os.urandom(16)), 16)
except NotImplementedError:
seed = int(time.time() * 1000) ^ os.getpid()
random.seed(seed)
def _pipe_cloexec() -> Tuple[int, int]:
r, w = os.pipe()
set_close_exec(r)
set_close_exec(w)
return r, w
_task_id = None
def fork_processes(num_processes: Optional[int], max_restarts: int = None) -> int:
"""Starts multiple worker processes.
If ``num_processes`` is None or <= 0, we detect the number of cores
available on this machine and fork that number of child
processes. If ``num_processes`` is given and > 0, we fork that
specific number of sub-processes.
Since we use processes and not threads, there is no shared memory
between any server code.
Note that multiple processes are not compatible with the autoreload
module (or the ``autoreload=True`` option to `tornado.web.Application`
which defaults to True when ``debug=True``).
When using multiple processes, no IOLoops can be created or
referenced until after the call to ``fork_processes``.
In each child process, ``fork_processes`` returns its *task id*, a
number between 0 and ``num_processes``. Processes that exit
abnormally (due to a signal or non-zero exit status) are restarted
with the same id (up to ``max_restarts`` times). In the parent
process, ``fork_processes`` returns None if all child processes
have exited normally, but will otherwise only exit by throwing an
exception.
max_restarts defaults to 100.
"""
if max_restarts is None:
max_restarts = 100
global _task_id
assert _task_id is None
if num_processes is None or num_processes <= 0:
num_processes = cpu_count()
gen_log.info("Starting %d processes", num_processes)
children = {}
def start_child(i: int) -> Optional[int]:
pid = os.fork()
if pid == 0:
# child process
_reseed_random()
global _task_id
_task_id = i
return i
else:
children[pid] = i
return None
for i in range(num_processes):
id = start_child(i)
if id is not None:
return id
num_restarts = 0
while children:
try:
pid, status = os.wait()
except OSError as e:
if errno_from_exception(e) == errno.EINTR:
continue
raise
if pid not in children:
continue
id = children.pop(pid)
if os.WIFSIGNALED(status):
gen_log.warning(
"child %d (pid %d) killed by signal %d, restarting",
id,
pid,
os.WTERMSIG(status),
)
elif os.WEXITSTATUS(status) != 0:
gen_log.warning(
"child %d (pid %d) exited with status %d, restarting",
id,
pid,
os.WEXITSTATUS(status),
)
else:
gen_log.info("child %d (pid %d) exited normally", id, pid)
continue
num_restarts += 1
if num_restarts > max_restarts:
raise RuntimeError("Too many child restarts, giving up")
new_id = start_child(id)
if new_id is not None:
return new_id
# All child processes exited cleanly, so exit the master process
# instead of just returning to right after the call to
# fork_processes (which will probably just start up another IOLoop
# unless the caller checks the return value).
sys.exit(0)
def task_id() -> Optional[int]:
"""Returns the current task id, if any.
Returns None if this process was not created by `fork_processes`.
"""
global _task_id
return _task_id
class Subprocess(object):
"""Wraps ``subprocess.Popen`` with IOStream support.
The constructor is the same as ``subprocess.Popen`` with the following
additions:
* ``stdin``, ``stdout``, and ``stderr`` may have the value
``tornado.process.Subprocess.STREAM``, which will make the corresponding
attribute of the resulting Subprocess a `.PipeIOStream`. If this option
is used, the caller is responsible for closing the streams when done
with them.
The ``Subprocess.STREAM`` option and the ``set_exit_callback`` and
``wait_for_exit`` methods do not work on Windows. There is
therefore no reason to use this class instead of
``subprocess.Popen`` on that platform.
.. versionchanged:: 5.0
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
"""
STREAM = object()
_initialized = False
_waiting = {} # type: ignore
_old_sigchld = None
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.io_loop = ioloop.IOLoop.current()
# All FDs we create should be closed on error; those in to_close
# should be closed in the parent process on success.
pipe_fds = [] # type: List[int]
to_close = [] # type: List[int]
if kwargs.get("stdin") is Subprocess.STREAM:
in_r, in_w = _pipe_cloexec()
kwargs["stdin"] = in_r
pipe_fds.extend((in_r, in_w))
to_close.append(in_r)
self.stdin = PipeIOStream(in_w)
if kwargs.get("stdout") is Subprocess.STREAM:
out_r, out_w = _pipe_cloexec()
kwargs["stdout"] = out_w
pipe_fds.extend((out_r, out_w))
to_close.append(out_w)
self.stdout = PipeIOStream(out_r)
if kwargs.get("stderr") is Subprocess.STREAM:
err_r, err_w = _pipe_cloexec()
kwargs["stderr"] = err_w
pipe_fds.extend((err_r, err_w))
to_close.append(err_w)
self.stderr = PipeIOStream(err_r)
try:
self.proc = subprocess.Popen(*args, **kwargs)
except:
for fd in pipe_fds:
os.close(fd)
raise
for fd in to_close:
os.close(fd)
self.pid = self.proc.pid
for attr in ["stdin", "stdout", "stderr"]:
if not hasattr(self, attr): # don't clobber streams set above
setattr(self, attr, getattr(self.proc, attr))
self._exit_callback = None # type: Optional[Callable[[int], None]]
self.returncode = None # type: Optional[int]
def set_exit_callback(self, callback: Callable[[int], None]) -> None:
"""Runs ``callback`` when this process exits.
The callback takes one argument, the return code of the process.
This method uses a ``SIGCHLD`` handler, which is a global setting
and may conflict if you have other libraries trying to handle the
same signal. If you are using more than one ``IOLoop`` it may
be necessary to call `Subprocess.initialize` first to designate
one ``IOLoop`` to run the signal handlers.
In many cases a close callback on the stdout or stderr streams
can be used as an alternative to an exit callback if the
signal handler is causing a problem.
"""
self._exit_callback = callback
Subprocess.initialize()
Subprocess._waiting[self.pid] = self
Subprocess._try_cleanup_process(self.pid)
def wait_for_exit(self, raise_error: bool = True) -> "Future[int]":
"""Returns a `.Future` which resolves when the process exits.
Usage::
ret = yield proc.wait_for_exit()
This is a coroutine-friendly alternative to `set_exit_callback`
(and a replacement for the blocking `subprocess.Popen.wait`).
By default, raises `subprocess.CalledProcessError` if the process
has a non-zero exit status. Use ``wait_for_exit(raise_error=False)``
to suppress this behavior and return the exit status without raising.
.. versionadded:: 4.2
"""
future = Future() # type: Future[int]
def callback(ret: int) -> None:
if ret != 0 and raise_error:
# Unfortunately we don't have the original args any more.
future_set_exception_unless_cancelled(
future, CalledProcessError(ret, "unknown")
)
else:
future_set_result_unless_cancelled(future, ret)
self.set_exit_callback(callback)
return future
@classmethod
def initialize(cls) -> None:
"""Initializes the ``SIGCHLD`` handler.
The signal handler is run on an `.IOLoop` to avoid locking issues.
Note that the `.IOLoop` used for signal handling need not be the
same one used by individual Subprocess objects (as long as the
``IOLoops`` are each running in separate threads).
.. versionchanged:: 5.0
The ``io_loop`` argument (deprecated since version 4.1) has been
removed.
"""
if cls._initialized:
return
io_loop = ioloop.IOLoop.current()
cls._old_sigchld = signal.signal(
signal.SIGCHLD,
lambda sig, frame: io_loop.add_callback_from_signal(cls._cleanup),
)
cls._initialized = True
@classmethod
def uninitialize(cls) -> None:
"""Removes the ``SIGCHLD`` handler."""
if not cls._initialized:
return
signal.signal(signal.SIGCHLD, cls._old_sigchld)
cls._initialized = False
@classmethod
def _cleanup(cls) -> None:
for pid in list(cls._waiting.keys()): # make a copy
cls._try_cleanup_process(pid)
@classmethod
def _try_cleanup_process(cls, pid: int) -> None:
try:
ret_pid, status = os.waitpid(pid, os.WNOHANG)
except OSError as e:
if errno_from_exception(e) == errno.ECHILD:
return
if ret_pid == 0:
return
assert ret_pid == pid
subproc = cls._waiting.pop(pid)
subproc.io_loop.add_callback_from_signal(subproc._set_returncode, status)
def _set_returncode(self, status: int) -> None:
if os.WIFSIGNALED(status):
self.returncode = -os.WTERMSIG(status)
else:
assert os.WIFEXITED(status)
self.returncode = os.WEXITSTATUS(status)
# We've taken over wait() duty from the subprocess.Popen
# object. If we don't inform it of the process's return code,
# it will log a warning at destruction in python 3.6+.
self.proc.returncode = self.returncode
if self._exit_callback:
callback = self._exit_callback
self._exit_callback = None
callback(self.returncode)

View File

@@ -0,0 +1,410 @@
# Copyright 2015 The Tornado Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Asynchronous queues for coroutines. These classes are very similar
to those provided in the standard library's `asyncio package
<https://docs.python.org/3/library/asyncio-queue.html>`_.
.. warning::
Unlike the standard library's `queue` module, the classes defined here
are *not* thread-safe. To use these queues from another thread,
use `.IOLoop.add_callback` to transfer control to the `.IOLoop` thread
before calling any queue methods.
"""
import collections
import datetime
import heapq
from tornado import gen, ioloop
from tornado.concurrent import Future, future_set_result_unless_cancelled
from tornado.locks import Event
from typing import Union, TypeVar, Generic, Awaitable
import typing
if typing.TYPE_CHECKING:
from typing import Deque, Tuple, List, Any # noqa: F401
_T = TypeVar("_T")
__all__ = ["Queue", "PriorityQueue", "LifoQueue", "QueueFull", "QueueEmpty"]
class QueueEmpty(Exception):
"""Raised by `.Queue.get_nowait` when the queue has no items."""
pass
class QueueFull(Exception):
"""Raised by `.Queue.put_nowait` when a queue is at its maximum size."""
pass
def _set_timeout(
future: Future, timeout: Union[None, float, datetime.timedelta]
) -> None:
if timeout:
def on_timeout() -> None:
if not future.done():
future.set_exception(gen.TimeoutError())
io_loop = ioloop.IOLoop.current()
timeout_handle = io_loop.add_timeout(timeout, on_timeout)
future.add_done_callback(lambda _: io_loop.remove_timeout(timeout_handle))
class _QueueIterator(Generic[_T]):
def __init__(self, q: "Queue[_T]") -> None:
self.q = q
def __anext__(self) -> Awaitable[_T]:
return self.q.get()
class Queue(Generic[_T]):
"""Coordinate producer and consumer coroutines.
If maxsize is 0 (the default) the queue size is unbounded.
.. testcode::
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.queues import Queue
q = Queue(maxsize=2)
async def consumer():
async for item in q:
try:
print('Doing work on %s' % item)
await gen.sleep(0.01)
finally:
q.task_done()
async def producer():
for item in range(5):
await q.put(item)
print('Put %s' % item)
async def main():
# Start consumer without waiting (since it never finishes).
IOLoop.current().spawn_callback(consumer)
await producer() # Wait for producer to put all tasks.
await q.join() # Wait for consumer to finish all tasks.
print('Done')
IOLoop.current().run_sync(main)
.. testoutput::
Put 0
Put 1
Doing work on 0
Put 2
Doing work on 1
Put 3
Doing work on 2
Put 4
Doing work on 3
Doing work on 4
Done
In versions of Python without native coroutines (before 3.5),
``consumer()`` could be written as::
@gen.coroutine
def consumer():
while True:
item = yield q.get()
try:
print('Doing work on %s' % item)
yield gen.sleep(0.01)
finally:
q.task_done()
.. versionchanged:: 4.3
Added ``async for`` support in Python 3.5.
"""
# Exact type depends on subclass. Could be another generic
# parameter and use protocols to be more precise here.
_queue = None # type: Any
def __init__(self, maxsize: int = 0) -> None:
if maxsize is None:
raise TypeError("maxsize can't be None")
if maxsize < 0:
raise ValueError("maxsize can't be negative")
self._maxsize = maxsize
self._init()
self._getters = collections.deque([]) # type: Deque[Future[_T]]
self._putters = collections.deque([]) # type: Deque[Tuple[_T, Future[None]]]
self._unfinished_tasks = 0
self._finished = Event()
self._finished.set()
@property
def maxsize(self) -> int:
"""Number of items allowed in the queue."""
return self._maxsize
def qsize(self) -> int:
"""Number of items in the queue."""
return len(self._queue)
def empty(self) -> bool:
return not self._queue
def full(self) -> bool:
if self.maxsize == 0:
return False
else:
return self.qsize() >= self.maxsize
def put(
self, item: _T, timeout: Union[float, datetime.timedelta] = None
) -> "Future[None]":
"""Put an item into the queue, perhaps waiting until there is room.
Returns a Future, which raises `tornado.util.TimeoutError` after a
timeout.
``timeout`` may be a number denoting a time (on the same
scale as `tornado.ioloop.IOLoop.time`, normally `time.time`), or a
`datetime.timedelta` object for a deadline relative to the
current time.
"""
future = Future() # type: Future[None]
try:
self.put_nowait(item)
except QueueFull:
self._putters.append((item, future))
_set_timeout(future, timeout)
else:
future.set_result(None)
return future
def put_nowait(self, item: _T) -> None:
"""Put an item into the queue without blocking.
If no free slot is immediately available, raise `QueueFull`.
"""
self._consume_expired()
if self._getters:
assert self.empty(), "queue non-empty, why are getters waiting?"
getter = self._getters.popleft()
self.__put_internal(item)
future_set_result_unless_cancelled(getter, self._get())
elif self.full():
raise QueueFull
else:
self.__put_internal(item)
def get(self, timeout: Union[float, datetime.timedelta] = None) -> Awaitable[_T]:
"""Remove and return an item from the queue.
Returns an awaitable which resolves once an item is available, or raises
`tornado.util.TimeoutError` after a timeout.
``timeout`` may be a number denoting a time (on the same
scale as `tornado.ioloop.IOLoop.time`, normally `time.time`), or a
`datetime.timedelta` object for a deadline relative to the
current time.
.. note::
The ``timeout`` argument of this method differs from that
of the standard library's `queue.Queue.get`. That method
interprets numeric values as relative timeouts; this one
interprets them as absolute deadlines and requires
``timedelta`` objects for relative timeouts (consistent
with other timeouts in Tornado).
"""
future = Future() # type: Future[_T]
try:
future.set_result(self.get_nowait())
except QueueEmpty:
self._getters.append(future)
_set_timeout(future, timeout)
return future
def get_nowait(self) -> _T:
"""Remove and return an item from the queue without blocking.
Return an item if one is immediately available, else raise
`QueueEmpty`.
"""
self._consume_expired()
if self._putters:
assert self.full(), "queue not full, why are putters waiting?"
item, putter = self._putters.popleft()
self.__put_internal(item)
future_set_result_unless_cancelled(putter, None)
return self._get()
elif self.qsize():
return self._get()
else:
raise QueueEmpty
def task_done(self) -> None:
"""Indicate that a formerly enqueued task is complete.
Used by queue consumers. For each `.get` used to fetch a task, a
subsequent call to `.task_done` tells the queue that the processing
on the task is complete.
If a `.join` is blocking, it resumes when all items have been
processed; that is, when every `.put` is matched by a `.task_done`.
Raises `ValueError` if called more times than `.put`.
"""
if self._unfinished_tasks <= 0:
raise ValueError("task_done() called too many times")
self._unfinished_tasks -= 1
if self._unfinished_tasks == 0:
self._finished.set()
def join(self, timeout: Union[float, datetime.timedelta] = None) -> Awaitable[None]:
"""Block until all items in the queue are processed.
Returns an awaitable, which raises `tornado.util.TimeoutError` after a
timeout.
"""
return self._finished.wait(timeout)
def __aiter__(self) -> _QueueIterator[_T]:
return _QueueIterator(self)
# These three are overridable in subclasses.
def _init(self) -> None:
self._queue = collections.deque()
def _get(self) -> _T:
return self._queue.popleft()
def _put(self, item: _T) -> None:
self._queue.append(item)
# End of the overridable methods.
def __put_internal(self, item: _T) -> None:
self._unfinished_tasks += 1
self._finished.clear()
self._put(item)
def _consume_expired(self) -> None:
# Remove timed-out waiters.
while self._putters and self._putters[0][1].done():
self._putters.popleft()
while self._getters and self._getters[0].done():
self._getters.popleft()
def __repr__(self) -> str:
return "<%s at %s %s>" % (type(self).__name__, hex(id(self)), self._format())
def __str__(self) -> str:
return "<%s %s>" % (type(self).__name__, self._format())
def _format(self) -> str:
result = "maxsize=%r" % (self.maxsize,)
if getattr(self, "_queue", None):
result += " queue=%r" % self._queue
if self._getters:
result += " getters[%s]" % len(self._getters)
if self._putters:
result += " putters[%s]" % len(self._putters)
if self._unfinished_tasks:
result += " tasks=%s" % self._unfinished_tasks
return result
class PriorityQueue(Queue):
"""A `.Queue` that retrieves entries in priority order, lowest first.
Entries are typically tuples like ``(priority number, data)``.
.. testcode::
from tornado.queues import PriorityQueue
q = PriorityQueue()
q.put((1, 'medium-priority item'))
q.put((0, 'high-priority item'))
q.put((10, 'low-priority item'))
print(q.get_nowait())
print(q.get_nowait())
print(q.get_nowait())
.. testoutput::
(0, 'high-priority item')
(1, 'medium-priority item')
(10, 'low-priority item')
"""
def _init(self) -> None:
self._queue = []
def _put(self, item: _T) -> None:
heapq.heappush(self._queue, item)
def _get(self) -> _T:
return heapq.heappop(self._queue)
class LifoQueue(Queue):
"""A `.Queue` that retrieves the most recently put items first.
.. testcode::
from tornado.queues import LifoQueue
q = LifoQueue()
q.put(3)
q.put(2)
q.put(1)
print(q.get_nowait())
print(q.get_nowait())
print(q.get_nowait())
.. testoutput::
1
2
3
"""
def _init(self) -> None:
self._queue = []
def _put(self, item: _T) -> None:
self._queue.append(item)
def _get(self) -> _T:
return self._queue.pop()

View File

@@ -0,0 +1,711 @@
# Copyright 2015 The Tornado Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Flexible routing implementation.
Tornado routes HTTP requests to appropriate handlers using `Router`
class implementations. The `tornado.web.Application` class is a
`Router` implementation and may be used directly, or the classes in
this module may be used for additional flexibility. The `RuleRouter`
class can match on more criteria than `.Application`, or the `Router`
interface can be subclassed for maximum customization.
`Router` interface extends `~.httputil.HTTPServerConnectionDelegate`
to provide additional routing capabilities. This also means that any
`Router` implementation can be used directly as a ``request_callback``
for `~.httpserver.HTTPServer` constructor.
`Router` subclass must implement a ``find_handler`` method to provide
a suitable `~.httputil.HTTPMessageDelegate` instance to handle the
request:
.. code-block:: python
class CustomRouter(Router):
def find_handler(self, request, **kwargs):
# some routing logic providing a suitable HTTPMessageDelegate instance
return MessageDelegate(request.connection)
class MessageDelegate(HTTPMessageDelegate):
def __init__(self, connection):
self.connection = connection
def finish(self):
self.connection.write_headers(
ResponseStartLine("HTTP/1.1", 200, "OK"),
HTTPHeaders({"Content-Length": "2"}),
b"OK")
self.connection.finish()
router = CustomRouter()
server = HTTPServer(router)
The main responsibility of `Router` implementation is to provide a
mapping from a request to `~.httputil.HTTPMessageDelegate` instance
that will handle this request. In the example above we can see that
routing is possible even without instantiating an `~.web.Application`.
For routing to `~.web.RequestHandler` implementations we need an
`~.web.Application` instance. `~.web.Application.get_handler_delegate`
provides a convenient way to create `~.httputil.HTTPMessageDelegate`
for a given request and `~.web.RequestHandler`.
Here is a simple example of how we can we route to
`~.web.RequestHandler` subclasses by HTTP method:
.. code-block:: python
resources = {}
class GetResource(RequestHandler):
def get(self, path):
if path not in resources:
raise HTTPError(404)
self.finish(resources[path])
class PostResource(RequestHandler):
def post(self, path):
resources[path] = self.request.body
class HTTPMethodRouter(Router):
def __init__(self, app):
self.app = app
def find_handler(self, request, **kwargs):
handler = GetResource if request.method == "GET" else PostResource
return self.app.get_handler_delegate(request, handler, path_args=[request.path])
router = HTTPMethodRouter(Application())
server = HTTPServer(router)
`ReversibleRouter` interface adds the ability to distinguish between
the routes and reverse them to the original urls using route's name
and additional arguments. `~.web.Application` is itself an
implementation of `ReversibleRouter` class.
`RuleRouter` and `ReversibleRuleRouter` are implementations of
`Router` and `ReversibleRouter` interfaces and can be used for
creating rule-based routing configurations.
Rules are instances of `Rule` class. They contain a `Matcher`, which
provides the logic for determining whether the rule is a match for a
particular request and a target, which can be one of the following.
1) An instance of `~.httputil.HTTPServerConnectionDelegate`:
.. code-block:: python
router = RuleRouter([
Rule(PathMatches("/handler"), ConnectionDelegate()),
# ... more rules
])
class ConnectionDelegate(HTTPServerConnectionDelegate):
def start_request(self, server_conn, request_conn):
return MessageDelegate(request_conn)
2) A callable accepting a single argument of `~.httputil.HTTPServerRequest` type:
.. code-block:: python
router = RuleRouter([
Rule(PathMatches("/callable"), request_callable)
])
def request_callable(request):
request.write(b"HTTP/1.1 200 OK\\r\\nContent-Length: 2\\r\\n\\r\\nOK")
request.finish()
3) Another `Router` instance:
.. code-block:: python
router = RuleRouter([
Rule(PathMatches("/router.*"), CustomRouter())
])
Of course a nested `RuleRouter` or a `~.web.Application` is allowed:
.. code-block:: python
router = RuleRouter([
Rule(HostMatches("example.com"), RuleRouter([
Rule(PathMatches("/app1/.*"), Application([(r"/app1/handler", Handler)]))),
]))
])
server = HTTPServer(router)
In the example below `RuleRouter` is used to route between applications:
.. code-block:: python
app1 = Application([
(r"/app1/handler", Handler1),
# other handlers ...
])
app2 = Application([
(r"/app2/handler", Handler2),
# other handlers ...
])
router = RuleRouter([
Rule(PathMatches("/app1.*"), app1),
Rule(PathMatches("/app2.*"), app2)
])
server = HTTPServer(router)
For more information on application-level routing see docs for `~.web.Application`.
.. versionadded:: 4.5
"""
import re
from functools import partial
from tornado import httputil
from tornado.httpserver import _CallableAdapter
from tornado.escape import url_escape, url_unescape, utf8
from tornado.log import app_log
from tornado.util import basestring_type, import_object, re_unescape, unicode_type
from typing import Any, Union, Optional, Awaitable, List, Dict, Pattern, Tuple, overload
class Router(httputil.HTTPServerConnectionDelegate):
"""Abstract router interface."""
def find_handler(
self, request: httputil.HTTPServerRequest, **kwargs: Any
) -> Optional[httputil.HTTPMessageDelegate]:
"""Must be implemented to return an appropriate instance of `~.httputil.HTTPMessageDelegate`
that can serve the request.
Routing implementations may pass additional kwargs to extend the routing logic.
:arg httputil.HTTPServerRequest request: current HTTP request.
:arg kwargs: additional keyword arguments passed by routing implementation.
:returns: an instance of `~.httputil.HTTPMessageDelegate` that will be used to
process the request.
"""
raise NotImplementedError()
def start_request(
self, server_conn: object, request_conn: httputil.HTTPConnection
) -> httputil.HTTPMessageDelegate:
return _RoutingDelegate(self, server_conn, request_conn)
class ReversibleRouter(Router):
"""Abstract router interface for routers that can handle named routes
and support reversing them to original urls.
"""
def reverse_url(self, name: str, *args: Any) -> Optional[str]:
"""Returns url string for a given route name and arguments
or ``None`` if no match is found.
:arg str name: route name.
:arg args: url parameters.
:returns: parametrized url string for a given route name (or ``None``).
"""
raise NotImplementedError()
class _RoutingDelegate(httputil.HTTPMessageDelegate):
def __init__(
self, router: Router, server_conn: object, request_conn: httputil.HTTPConnection
) -> None:
self.server_conn = server_conn
self.request_conn = request_conn
self.delegate = None # type: Optional[httputil.HTTPMessageDelegate]
self.router = router # type: Router
def headers_received(
self,
start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine],
headers: httputil.HTTPHeaders,
) -> Optional[Awaitable[None]]:
assert isinstance(start_line, httputil.RequestStartLine)
request = httputil.HTTPServerRequest(
connection=self.request_conn,
server_connection=self.server_conn,
start_line=start_line,
headers=headers,
)
self.delegate = self.router.find_handler(request)
if self.delegate is None:
app_log.debug(
"Delegate for %s %s request not found",
start_line.method,
start_line.path,
)
self.delegate = _DefaultMessageDelegate(self.request_conn)
return self.delegate.headers_received(start_line, headers)
def data_received(self, chunk: bytes) -> Optional[Awaitable[None]]:
assert self.delegate is not None
return self.delegate.data_received(chunk)
def finish(self) -> None:
assert self.delegate is not None
self.delegate.finish()
def on_connection_close(self) -> None:
assert self.delegate is not None
self.delegate.on_connection_close()
class _DefaultMessageDelegate(httputil.HTTPMessageDelegate):
def __init__(self, connection: httputil.HTTPConnection) -> None:
self.connection = connection
def finish(self) -> None:
self.connection.write_headers(
httputil.ResponseStartLine("HTTP/1.1", 404, "Not Found"),
httputil.HTTPHeaders(),
)
self.connection.finish()
# _RuleList can either contain pre-constructed Rules or a sequence of
# arguments to be passed to the Rule constructor.
_RuleList = List[
Union[
"Rule",
List[Any], # Can't do detailed typechecking of lists.
Tuple[Union[str, "Matcher"], Any],
Tuple[Union[str, "Matcher"], Any, Dict[str, Any]],
Tuple[Union[str, "Matcher"], Any, Dict[str, Any], str],
]
]
class RuleRouter(Router):
"""Rule-based router implementation."""
def __init__(self, rules: _RuleList = None) -> None:
"""Constructs a router from an ordered list of rules::
RuleRouter([
Rule(PathMatches("/handler"), Target),
# ... more rules
])
You can also omit explicit `Rule` constructor and use tuples of arguments::
RuleRouter([
(PathMatches("/handler"), Target),
])
`PathMatches` is a default matcher, so the example above can be simplified::
RuleRouter([
("/handler", Target),
])
In the examples above, ``Target`` can be a nested `Router` instance, an instance of
`~.httputil.HTTPServerConnectionDelegate` or an old-style callable,
accepting a request argument.
:arg rules: a list of `Rule` instances or tuples of `Rule`
constructor arguments.
"""
self.rules = [] # type: List[Rule]
if rules:
self.add_rules(rules)
def add_rules(self, rules: _RuleList) -> None:
"""Appends new rules to the router.
:arg rules: a list of Rule instances (or tuples of arguments, which are
passed to Rule constructor).
"""
for rule in rules:
if isinstance(rule, (tuple, list)):
assert len(rule) in (2, 3, 4)
if isinstance(rule[0], basestring_type):
rule = Rule(PathMatches(rule[0]), *rule[1:])
else:
rule = Rule(*rule)
self.rules.append(self.process_rule(rule))
def process_rule(self, rule: "Rule") -> "Rule":
"""Override this method for additional preprocessing of each rule.
:arg Rule rule: a rule to be processed.
:returns: the same or modified Rule instance.
"""
return rule
def find_handler(
self, request: httputil.HTTPServerRequest, **kwargs: Any
) -> Optional[httputil.HTTPMessageDelegate]:
for rule in self.rules:
target_params = rule.matcher.match(request)
if target_params is not None:
if rule.target_kwargs:
target_params["target_kwargs"] = rule.target_kwargs
delegate = self.get_target_delegate(
rule.target, request, **target_params
)
if delegate is not None:
return delegate
return None
def get_target_delegate(
self, target: Any, request: httputil.HTTPServerRequest, **target_params: Any
) -> Optional[httputil.HTTPMessageDelegate]:
"""Returns an instance of `~.httputil.HTTPMessageDelegate` for a
Rule's target. This method is called by `~.find_handler` and can be
extended to provide additional target types.
:arg target: a Rule's target.
:arg httputil.HTTPServerRequest request: current request.
:arg target_params: additional parameters that can be useful
for `~.httputil.HTTPMessageDelegate` creation.
"""
if isinstance(target, Router):
return target.find_handler(request, **target_params)
elif isinstance(target, httputil.HTTPServerConnectionDelegate):
assert request.connection is not None
return target.start_request(request.server_connection, request.connection)
elif callable(target):
assert request.connection is not None
return _CallableAdapter(
partial(target, **target_params), request.connection
)
return None
class ReversibleRuleRouter(ReversibleRouter, RuleRouter):
"""A rule-based router that implements ``reverse_url`` method.
Each rule added to this router may have a ``name`` attribute that can be
used to reconstruct an original uri. The actual reconstruction takes place
in a rule's matcher (see `Matcher.reverse`).
"""
def __init__(self, rules: _RuleList = None) -> None:
self.named_rules = {} # type: Dict[str, Any]
super(ReversibleRuleRouter, self).__init__(rules)
def process_rule(self, rule: "Rule") -> "Rule":
rule = super(ReversibleRuleRouter, self).process_rule(rule)
if rule.name:
if rule.name in self.named_rules:
app_log.warning(
"Multiple handlers named %s; replacing previous value", rule.name
)
self.named_rules[rule.name] = rule
return rule
def reverse_url(self, name: str, *args: Any) -> Optional[str]:
if name in self.named_rules:
return self.named_rules[name].matcher.reverse(*args)
for rule in self.rules:
if isinstance(rule.target, ReversibleRouter):
reversed_url = rule.target.reverse_url(name, *args)
if reversed_url is not None:
return reversed_url
return None
class Rule(object):
"""A routing rule."""
def __init__(
self,
matcher: "Matcher",
target: Any,
target_kwargs: Dict[str, Any] = None,
name: str = None,
) -> None:
"""Constructs a Rule instance.
:arg Matcher matcher: a `Matcher` instance used for determining
whether the rule should be considered a match for a specific
request.
:arg target: a Rule's target (typically a ``RequestHandler`` or
`~.httputil.HTTPServerConnectionDelegate` subclass or even a nested `Router`,
depending on routing implementation).
:arg dict target_kwargs: a dict of parameters that can be useful
at the moment of target instantiation (for example, ``status_code``
for a ``RequestHandler`` subclass). They end up in
``target_params['target_kwargs']`` of `RuleRouter.get_target_delegate`
method.
:arg str name: the name of the rule that can be used to find it
in `ReversibleRouter.reverse_url` implementation.
"""
if isinstance(target, str):
# import the Module and instantiate the class
# Must be a fully qualified name (module.ClassName)
target = import_object(target)
self.matcher = matcher # type: Matcher
self.target = target
self.target_kwargs = target_kwargs if target_kwargs else {}
self.name = name
def reverse(self, *args: Any) -> Optional[str]:
return self.matcher.reverse(*args)
def __repr__(self) -> str:
return "%s(%r, %s, kwargs=%r, name=%r)" % (
self.__class__.__name__,
self.matcher,
self.target,
self.target_kwargs,
self.name,
)
class Matcher(object):
"""Represents a matcher for request features."""
def match(self, request: httputil.HTTPServerRequest) -> Optional[Dict[str, Any]]:
"""Matches current instance against the request.
:arg httputil.HTTPServerRequest request: current HTTP request
:returns: a dict of parameters to be passed to the target handler
(for example, ``handler_kwargs``, ``path_args``, ``path_kwargs``
can be passed for proper `~.web.RequestHandler` instantiation).
An empty dict is a valid (and common) return value to indicate a match
when the argument-passing features are not used.
``None`` must be returned to indicate that there is no match."""
raise NotImplementedError()
def reverse(self, *args: Any) -> Optional[str]:
"""Reconstructs full url from matcher instance and additional arguments."""
return None
class AnyMatches(Matcher):
"""Matches any request."""
def match(self, request: httputil.HTTPServerRequest) -> Optional[Dict[str, Any]]:
return {}
class HostMatches(Matcher):
"""Matches requests from hosts specified by ``host_pattern`` regex."""
def __init__(self, host_pattern: Union[str, Pattern]) -> None:
if isinstance(host_pattern, basestring_type):
if not host_pattern.endswith("$"):
host_pattern += "$"
self.host_pattern = re.compile(host_pattern)
else:
self.host_pattern = host_pattern
def match(self, request: httputil.HTTPServerRequest) -> Optional[Dict[str, Any]]:
if self.host_pattern.match(request.host_name):
return {}
return None
class DefaultHostMatches(Matcher):
"""Matches requests from host that is equal to application's default_host.
Always returns no match if ``X-Real-Ip`` header is present.
"""
def __init__(self, application: Any, host_pattern: Pattern) -> None:
self.application = application
self.host_pattern = host_pattern
def match(self, request: httputil.HTTPServerRequest) -> Optional[Dict[str, Any]]:
# Look for default host if not behind load balancer (for debugging)
if "X-Real-Ip" not in request.headers:
if self.host_pattern.match(self.application.default_host):
return {}
return None
class PathMatches(Matcher):
"""Matches requests with paths specified by ``path_pattern`` regex."""
def __init__(self, path_pattern: Union[str, Pattern]) -> None:
if isinstance(path_pattern, basestring_type):
if not path_pattern.endswith("$"):
path_pattern += "$"
self.regex = re.compile(path_pattern)
else:
self.regex = path_pattern
assert len(self.regex.groupindex) in (0, self.regex.groups), (
"groups in url regexes must either be all named or all "
"positional: %r" % self.regex.pattern
)
self._path, self._group_count = self._find_groups()
def match(self, request: httputil.HTTPServerRequest) -> Optional[Dict[str, Any]]:
match = self.regex.match(request.path)
if match is None:
return None
if not self.regex.groups:
return {}
path_args = [] # type: List[bytes]
path_kwargs = {} # type: Dict[str, bytes]
# Pass matched groups to the handler. Since
# match.groups() includes both named and
# unnamed groups, we want to use either groups
# or groupdict but not both.
if self.regex.groupindex:
path_kwargs = dict(
(str(k), _unquote_or_none(v)) for (k, v) in match.groupdict().items()
)
else:
path_args = [_unquote_or_none(s) for s in match.groups()]
return dict(path_args=path_args, path_kwargs=path_kwargs)
def reverse(self, *args: Any) -> Optional[str]:
if self._path is None:
raise ValueError("Cannot reverse url regex " + self.regex.pattern)
assert len(args) == self._group_count, (
"required number of arguments " "not found"
)
if not len(args):
return self._path
converted_args = []
for a in args:
if not isinstance(a, (unicode_type, bytes)):
a = str(a)
converted_args.append(url_escape(utf8(a), plus=False))
return self._path % tuple(converted_args)
def _find_groups(self) -> Tuple[Optional[str], Optional[int]]:
"""Returns a tuple (reverse string, group count) for a url.
For example: Given the url pattern /([0-9]{4})/([a-z-]+)/, this method
would return ('/%s/%s/', 2).
"""
pattern = self.regex.pattern
if pattern.startswith("^"):
pattern = pattern[1:]
if pattern.endswith("$"):
pattern = pattern[:-1]
if self.regex.groups != pattern.count("("):
# The pattern is too complicated for our simplistic matching,
# so we can't support reversing it.
return None, None
pieces = []
for fragment in pattern.split("("):
if ")" in fragment:
paren_loc = fragment.index(")")
if paren_loc >= 0:
pieces.append("%s" + fragment[paren_loc + 1 :])
else:
try:
unescaped_fragment = re_unescape(fragment)
except ValueError:
# If we can't unescape part of it, we can't
# reverse this url.
return (None, None)
pieces.append(unescaped_fragment)
return "".join(pieces), self.regex.groups
class URLSpec(Rule):
"""Specifies mappings between URLs and handlers.
.. versionchanged: 4.5
`URLSpec` is now a subclass of a `Rule` with `PathMatches` matcher and is preserved for
backwards compatibility.
"""
def __init__(
self,
pattern: Union[str, Pattern],
handler: Any,
kwargs: Dict[str, Any] = None,
name: str = None,
) -> None:
"""Parameters:
* ``pattern``: Regular expression to be matched. Any capturing
groups in the regex will be passed in to the handler's
get/post/etc methods as arguments (by keyword if named, by
position if unnamed. Named and unnamed capturing groups
may not be mixed in the same rule).
* ``handler``: `~.web.RequestHandler` subclass to be invoked.
* ``kwargs`` (optional): A dictionary of additional arguments
to be passed to the handler's constructor.
* ``name`` (optional): A name for this handler. Used by
`~.web.Application.reverse_url`.
"""
matcher = PathMatches(pattern)
super(URLSpec, self).__init__(matcher, handler, kwargs, name)
self.regex = matcher.regex
self.handler_class = self.target
self.kwargs = kwargs
def __repr__(self) -> str:
return "%s(%r, %s, kwargs=%r, name=%r)" % (
self.__class__.__name__,
self.regex.pattern,
self.handler_class,
self.kwargs,
self.name,
)
@overload
def _unquote_or_none(s: str) -> bytes:
pass
@overload # noqa: F811
def _unquote_or_none(s: None) -> None:
pass
def _unquote_or_none(s: Optional[str]) -> Optional[bytes]: # noqa: F811
"""None-safe wrapper around url_unescape to handle unmatched optional
groups correctly.
Note that args are passed as bytes so the handler can decide what
encoding to use.
"""
if s is None:
return s
return url_unescape(s, encoding=None, plus=False)

View File

@@ -0,0 +1,692 @@
from tornado.escape import _unicode
from tornado import gen
from tornado.httpclient import (
HTTPResponse,
HTTPError,
AsyncHTTPClient,
main,
_RequestProxy,
HTTPRequest,
)
from tornado import httputil
from tornado.http1connection import HTTP1Connection, HTTP1ConnectionParameters
from tornado.ioloop import IOLoop
from tornado.iostream import StreamClosedError, IOStream
from tornado.netutil import (
Resolver,
OverrideResolver,
_client_ssl_defaults,
is_valid_ip,
)
from tornado.log import gen_log
from tornado.tcpclient import TCPClient
import base64
import collections
import copy
import functools
import re
import socket
import ssl
import sys
import time
from io import BytesIO
import urllib.parse
from typing import Dict, Any, Callable, Optional, Type, Union
from types import TracebackType
import typing
if typing.TYPE_CHECKING:
from typing import Deque, Tuple, List # noqa: F401
class HTTPTimeoutError(HTTPError):
"""Error raised by SimpleAsyncHTTPClient on timeout.
For historical reasons, this is a subclass of `.HTTPClientError`
which simulates a response code of 599.
.. versionadded:: 5.1
"""
def __init__(self, message: str) -> None:
super(HTTPTimeoutError, self).__init__(599, message=message)
def __str__(self) -> str:
return self.message or "Timeout"
class HTTPStreamClosedError(HTTPError):
"""Error raised by SimpleAsyncHTTPClient when the underlying stream is closed.
When a more specific exception is available (such as `ConnectionResetError`),
it may be raised instead of this one.
For historical reasons, this is a subclass of `.HTTPClientError`
which simulates a response code of 599.
.. versionadded:: 5.1
"""
def __init__(self, message: str) -> None:
super(HTTPStreamClosedError, self).__init__(599, message=message)
def __str__(self) -> str:
return self.message or "Stream closed"
class SimpleAsyncHTTPClient(AsyncHTTPClient):
"""Non-blocking HTTP client with no external dependencies.
This class implements an HTTP 1.1 client on top of Tornado's IOStreams.
Some features found in the curl-based AsyncHTTPClient are not yet
supported. In particular, proxies are not supported, connections
are not reused, and callers cannot select the network interface to be
used.
"""
def initialize( # type: ignore
self,
max_clients: int = 10,
hostname_mapping: Dict[str, str] = None,
max_buffer_size: int = 104857600,
resolver: Resolver = None,
defaults: Dict[str, Any] = None,
max_header_size: int = None,
max_body_size: int = None,
) -> None:
"""Creates a AsyncHTTPClient.
Only a single AsyncHTTPClient instance exists per IOLoop
in order to provide limitations on the number of pending connections.
``force_instance=True`` may be used to suppress this behavior.
Note that because of this implicit reuse, unless ``force_instance``
is used, only the first call to the constructor actually uses
its arguments. It is recommended to use the ``configure`` method
instead of the constructor to ensure that arguments take effect.
``max_clients`` is the number of concurrent requests that can be
in progress; when this limit is reached additional requests will be
queued. Note that time spent waiting in this queue still counts
against the ``request_timeout``.
``hostname_mapping`` is a dictionary mapping hostnames to IP addresses.
It can be used to make local DNS changes when modifying system-wide
settings like ``/etc/hosts`` is not possible or desirable (e.g. in
unittests).
``max_buffer_size`` (default 100MB) is the number of bytes
that can be read into memory at once. ``max_body_size``
(defaults to ``max_buffer_size``) is the largest response body
that the client will accept. Without a
``streaming_callback``, the smaller of these two limits
applies; with a ``streaming_callback`` only ``max_body_size``
does.
.. versionchanged:: 4.2
Added the ``max_body_size`` argument.
"""
super(SimpleAsyncHTTPClient, self).initialize(defaults=defaults)
self.max_clients = max_clients
self.queue = (
collections.deque()
) # type: Deque[Tuple[object, HTTPRequest, Callable[[HTTPResponse], None]]]
self.active = (
{}
) # type: Dict[object, Tuple[HTTPRequest, Callable[[HTTPResponse], None]]]
self.waiting = (
{}
) # type: Dict[object, Tuple[HTTPRequest, Callable[[HTTPResponse], None], object]]
self.max_buffer_size = max_buffer_size
self.max_header_size = max_header_size
self.max_body_size = max_body_size
# TCPClient could create a Resolver for us, but we have to do it
# ourselves to support hostname_mapping.
if resolver:
self.resolver = resolver
self.own_resolver = False
else:
self.resolver = Resolver()
self.own_resolver = True
if hostname_mapping is not None:
self.resolver = OverrideResolver(
resolver=self.resolver, mapping=hostname_mapping
)
self.tcp_client = TCPClient(resolver=self.resolver)
def close(self) -> None:
super(SimpleAsyncHTTPClient, self).close()
if self.own_resolver:
self.resolver.close()
self.tcp_client.close()
def fetch_impl(
self, request: HTTPRequest, callback: Callable[[HTTPResponse], None]
) -> None:
key = object()
self.queue.append((key, request, callback))
if not len(self.active) < self.max_clients:
assert request.connect_timeout is not None
assert request.request_timeout is not None
timeout_handle = self.io_loop.add_timeout(
self.io_loop.time()
+ min(request.connect_timeout, request.request_timeout),
functools.partial(self._on_timeout, key, "in request queue"),
)
else:
timeout_handle = None
self.waiting[key] = (request, callback, timeout_handle)
self._process_queue()
if self.queue:
gen_log.debug(
"max_clients limit reached, request queued. "
"%d active, %d queued requests." % (len(self.active), len(self.queue))
)
def _process_queue(self) -> None:
while self.queue and len(self.active) < self.max_clients:
key, request, callback = self.queue.popleft()
if key not in self.waiting:
continue
self._remove_timeout(key)
self.active[key] = (request, callback)
release_callback = functools.partial(self._release_fetch, key)
self._handle_request(request, release_callback, callback)
def _connection_class(self) -> type:
return _HTTPConnection
def _handle_request(
self,
request: HTTPRequest,
release_callback: Callable[[], None],
final_callback: Callable[[HTTPResponse], None],
) -> None:
self._connection_class()(
self,
request,
release_callback,
final_callback,
self.max_buffer_size,
self.tcp_client,
self.max_header_size,
self.max_body_size,
)
def _release_fetch(self, key: object) -> None:
del self.active[key]
self._process_queue()
def _remove_timeout(self, key: object) -> None:
if key in self.waiting:
request, callback, timeout_handle = self.waiting[key]
if timeout_handle is not None:
self.io_loop.remove_timeout(timeout_handle)
del self.waiting[key]
def _on_timeout(self, key: object, info: str = None) -> None:
"""Timeout callback of request.
Construct a timeout HTTPResponse when a timeout occurs.
:arg object key: A simple object to mark the request.
:info string key: More detailed timeout information.
"""
request, callback, timeout_handle = self.waiting[key]
self.queue.remove((key, request, callback))
error_message = "Timeout {0}".format(info) if info else "Timeout"
timeout_response = HTTPResponse(
request,
599,
error=HTTPTimeoutError(error_message),
request_time=self.io_loop.time() - request.start_time,
)
self.io_loop.add_callback(callback, timeout_response)
del self.waiting[key]
class _HTTPConnection(httputil.HTTPMessageDelegate):
_SUPPORTED_METHODS = set(
["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]
)
def __init__(
self,
client: Optional[SimpleAsyncHTTPClient],
request: HTTPRequest,
release_callback: Callable[[], None],
final_callback: Callable[[HTTPResponse], None],
max_buffer_size: int,
tcp_client: TCPClient,
max_header_size: int,
max_body_size: int,
) -> None:
self.io_loop = IOLoop.current()
self.start_time = self.io_loop.time()
self.start_wall_time = time.time()
self.client = client
self.request = request
self.release_callback = release_callback
self.final_callback = final_callback
self.max_buffer_size = max_buffer_size
self.tcp_client = tcp_client
self.max_header_size = max_header_size
self.max_body_size = max_body_size
self.code = None # type: Optional[int]
self.headers = None # type: Optional[httputil.HTTPHeaders]
self.chunks = [] # type: List[bytes]
self._decompressor = None
# Timeout handle returned by IOLoop.add_timeout
self._timeout = None # type: object
self._sockaddr = None
IOLoop.current().add_future(
gen.convert_yielded(self.run()), lambda f: f.result()
)
async def run(self) -> None:
try:
self.parsed = urllib.parse.urlsplit(_unicode(self.request.url))
if self.parsed.scheme not in ("http", "https"):
raise ValueError("Unsupported url scheme: %s" % self.request.url)
# urlsplit results have hostname and port results, but they
# didn't support ipv6 literals until python 2.7.
netloc = self.parsed.netloc
if "@" in netloc:
userpass, _, netloc = netloc.rpartition("@")
host, port = httputil.split_host_and_port(netloc)
if port is None:
port = 443 if self.parsed.scheme == "https" else 80
if re.match(r"^\[.*\]$", host):
# raw ipv6 addresses in urls are enclosed in brackets
host = host[1:-1]
self.parsed_hostname = host # save final host for _on_connect
if self.request.allow_ipv6 is False:
af = socket.AF_INET
else:
af = socket.AF_UNSPEC
ssl_options = self._get_ssl_options(self.parsed.scheme)
source_ip = None
if self.request.network_interface:
if is_valid_ip(self.request.network_interface):
source_ip = self.request.network_interface
else:
raise ValueError(
"Unrecognized IPv4 or IPv6 address for network_interface, got %r"
% (self.request.network_interface,)
)
timeout = min(self.request.connect_timeout, self.request.request_timeout)
if timeout:
self._timeout = self.io_loop.add_timeout(
self.start_time + timeout,
functools.partial(self._on_timeout, "while connecting"),
)
stream = await self.tcp_client.connect(
host,
port,
af=af,
ssl_options=ssl_options,
max_buffer_size=self.max_buffer_size,
source_ip=source_ip,
)
if self.final_callback is None:
# final_callback is cleared if we've hit our timeout.
stream.close()
return
self.stream = stream
self.stream.set_close_callback(self.on_connection_close)
self._remove_timeout()
if self.final_callback is None:
return
if self.request.request_timeout:
self._timeout = self.io_loop.add_timeout(
self.start_time + self.request.request_timeout,
functools.partial(self._on_timeout, "during request"),
)
if (
self.request.method not in self._SUPPORTED_METHODS
and not self.request.allow_nonstandard_methods
):
raise KeyError("unknown method %s" % self.request.method)
for key in (
"proxy_host",
"proxy_port",
"proxy_username",
"proxy_password",
"proxy_auth_mode",
):
if getattr(self.request, key, None):
raise NotImplementedError("%s not supported" % key)
if "Connection" not in self.request.headers:
self.request.headers["Connection"] = "close"
if "Host" not in self.request.headers:
if "@" in self.parsed.netloc:
self.request.headers["Host"] = self.parsed.netloc.rpartition(
"@"
)[-1]
else:
self.request.headers["Host"] = self.parsed.netloc
username, password = None, None
if self.parsed.username is not None:
username, password = self.parsed.username, self.parsed.password
elif self.request.auth_username is not None:
username = self.request.auth_username
password = self.request.auth_password or ""
if username is not None:
assert password is not None
if self.request.auth_mode not in (None, "basic"):
raise ValueError(
"unsupported auth_mode %s", self.request.auth_mode
)
self.request.headers["Authorization"] = "Basic " + _unicode(
base64.b64encode(
httputil.encode_username_password(username, password)
)
)
if self.request.user_agent:
self.request.headers["User-Agent"] = self.request.user_agent
if not self.request.allow_nonstandard_methods:
# Some HTTP methods nearly always have bodies while others
# almost never do. Fail in this case unless the user has
# opted out of sanity checks with allow_nonstandard_methods.
body_expected = self.request.method in ("POST", "PATCH", "PUT")
body_present = (
self.request.body is not None
or self.request.body_producer is not None
)
if (body_expected and not body_present) or (
body_present and not body_expected
):
raise ValueError(
"Body must %sbe None for method %s (unless "
"allow_nonstandard_methods is true)"
% ("not " if body_expected else "", self.request.method)
)
if self.request.expect_100_continue:
self.request.headers["Expect"] = "100-continue"
if self.request.body is not None:
# When body_producer is used the caller is responsible for
# setting Content-Length (or else chunked encoding will be used).
self.request.headers["Content-Length"] = str(len(self.request.body))
if (
self.request.method == "POST"
and "Content-Type" not in self.request.headers
):
self.request.headers[
"Content-Type"
] = "application/x-www-form-urlencoded"
if self.request.decompress_response:
self.request.headers["Accept-Encoding"] = "gzip"
req_path = (self.parsed.path or "/") + (
("?" + self.parsed.query) if self.parsed.query else ""
)
self.connection = self._create_connection(stream)
start_line = httputil.RequestStartLine(
self.request.method, req_path, ""
)
self.connection.write_headers(start_line, self.request.headers)
if self.request.expect_100_continue:
await self.connection.read_response(self)
else:
await self._write_body(True)
except Exception:
if not self._handle_exception(*sys.exc_info()):
raise
def _get_ssl_options(
self, scheme: str
) -> Union[None, Dict[str, Any], ssl.SSLContext]:
if scheme == "https":
if self.request.ssl_options is not None:
return self.request.ssl_options
# If we are using the defaults, don't construct a
# new SSLContext.
if (
self.request.validate_cert
and self.request.ca_certs is None
and self.request.client_cert is None
and self.request.client_key is None
):
return _client_ssl_defaults
ssl_ctx = ssl.create_default_context(
ssl.Purpose.SERVER_AUTH, cafile=self.request.ca_certs
)
if not self.request.validate_cert:
ssl_ctx.check_hostname = False
ssl_ctx.verify_mode = ssl.CERT_NONE
if self.request.client_cert is not None:
ssl_ctx.load_cert_chain(
self.request.client_cert, self.request.client_key
)
if hasattr(ssl, "OP_NO_COMPRESSION"):
# See netutil.ssl_options_to_context
ssl_ctx.options |= ssl.OP_NO_COMPRESSION
return ssl_ctx
return None
def _on_timeout(self, info: str = None) -> None:
"""Timeout callback of _HTTPConnection instance.
Raise a `HTTPTimeoutError` when a timeout occurs.
:info string key: More detailed timeout information.
"""
self._timeout = None
error_message = "Timeout {0}".format(info) if info else "Timeout"
if self.final_callback is not None:
self._handle_exception(
HTTPTimeoutError, HTTPTimeoutError(error_message), None
)
def _remove_timeout(self) -> None:
if self._timeout is not None:
self.io_loop.remove_timeout(self._timeout)
self._timeout = None
def _create_connection(self, stream: IOStream) -> HTTP1Connection:
stream.set_nodelay(True)
connection = HTTP1Connection(
stream,
True,
HTTP1ConnectionParameters(
no_keep_alive=True,
max_header_size=self.max_header_size,
max_body_size=self.max_body_size,
decompress=bool(self.request.decompress_response),
),
self._sockaddr,
)
return connection
async def _write_body(self, start_read: bool) -> None:
if self.request.body is not None:
self.connection.write(self.request.body)
elif self.request.body_producer is not None:
fut = self.request.body_producer(self.connection.write)
if fut is not None:
await fut
self.connection.finish()
if start_read:
try:
await self.connection.read_response(self)
except StreamClosedError:
if not self._handle_exception(*sys.exc_info()):
raise
def _release(self) -> None:
if self.release_callback is not None:
release_callback = self.release_callback
self.release_callback = None # type: ignore
release_callback()
def _run_callback(self, response: HTTPResponse) -> None:
self._release()
if self.final_callback is not None:
final_callback = self.final_callback
self.final_callback = None # type: ignore
self.io_loop.add_callback(final_callback, response)
def _handle_exception(
self,
typ: "Optional[Type[BaseException]]",
value: Optional[BaseException],
tb: Optional[TracebackType],
) -> bool:
if self.final_callback:
self._remove_timeout()
if isinstance(value, StreamClosedError):
if value.real_error is None:
value = HTTPStreamClosedError("Stream closed")
else:
value = value.real_error
self._run_callback(
HTTPResponse(
self.request,
599,
error=value,
request_time=self.io_loop.time() - self.start_time,
start_time=self.start_wall_time,
)
)
if hasattr(self, "stream"):
# TODO: this may cause a StreamClosedError to be raised
# by the connection's Future. Should we cancel the
# connection more gracefully?
self.stream.close()
return True
else:
# If our callback has already been called, we are probably
# catching an exception that is not caused by us but rather
# some child of our callback. Rather than drop it on the floor,
# pass it along, unless it's just the stream being closed.
return isinstance(value, StreamClosedError)
def on_connection_close(self) -> None:
if self.final_callback is not None:
message = "Connection closed"
if self.stream.error:
raise self.stream.error
try:
raise HTTPStreamClosedError(message)
except HTTPStreamClosedError:
self._handle_exception(*sys.exc_info())
async def headers_received(
self,
first_line: Union[httputil.ResponseStartLine, httputil.RequestStartLine],
headers: httputil.HTTPHeaders,
) -> None:
assert isinstance(first_line, httputil.ResponseStartLine)
if self.request.expect_100_continue and first_line.code == 100:
await self._write_body(False)
return
self.code = first_line.code
self.reason = first_line.reason
self.headers = headers
if self._should_follow_redirect():
return
if self.request.header_callback is not None:
# Reassemble the start line.
self.request.header_callback("%s %s %s\r\n" % first_line)
for k, v in self.headers.get_all():
self.request.header_callback("%s: %s\r\n" % (k, v))
self.request.header_callback("\r\n")
def _should_follow_redirect(self) -> bool:
if self.request.follow_redirects:
assert self.request.max_redirects is not None
return (
self.code in (301, 302, 303, 307, 308)
and self.request.max_redirects > 0
and self.headers is not None
and self.headers.get("Location") is not None
)
return False
def finish(self) -> None:
assert self.code is not None
data = b"".join(self.chunks)
self._remove_timeout()
original_request = getattr(self.request, "original_request", self.request)
if self._should_follow_redirect():
assert isinstance(self.request, _RequestProxy)
new_request = copy.copy(self.request.request)
new_request.url = urllib.parse.urljoin(
self.request.url, self.headers["Location"]
)
new_request.max_redirects = self.request.max_redirects - 1
del new_request.headers["Host"]
# https://tools.ietf.org/html/rfc7231#section-6.4
#
# The original HTTP spec said that after a 301 or 302
# redirect, the request method should be preserved.
# However, browsers implemented this by changing the
# method to GET, and the behavior stuck. 303 redirects
# always specified this POST-to-GET behavior (arguably 303
# redirects should change *all* requests to GET, but
# libcurl only does this for POST so we follow their
# example).
if self.code in (301, 302, 303) and self.request.method == "POST":
new_request.method = "GET"
new_request.body = None
for h in [
"Content-Length",
"Content-Type",
"Content-Encoding",
"Transfer-Encoding",
]:
try:
del self.request.headers[h]
except KeyError:
pass
new_request.original_request = original_request
final_callback = self.final_callback
self.final_callback = None
self._release()
fut = self.client.fetch(new_request, raise_error=False)
fut.add_done_callback(lambda f: final_callback(f.result()))
self._on_end_request()
return
if self.request.streaming_callback:
buffer = BytesIO()
else:
buffer = BytesIO(data) # TODO: don't require one big string?
response = HTTPResponse(
original_request,
self.code,
reason=getattr(self, "reason", None),
headers=self.headers,
request_time=self.io_loop.time() - self.start_time,
start_time=self.start_wall_time,
buffer=buffer,
effective_url=self.request.url,
)
self._run_callback(response)
self._on_end_request()
def _on_end_request(self) -> None:
self.stream.close()
def data_received(self, chunk: bytes) -> None:
if self._should_follow_redirect():
# We're going to follow a redirect so just discard the body.
return
if self.request.streaming_callback is not None:
self.request.streaming_callback(chunk)
else:
self.chunks.append(chunk)
if __name__ == "__main__":
AsyncHTTPClient.configure(SimpleAsyncHTTPClient)
main()

View File

@@ -0,0 +1,334 @@
#
# Copyright 2014 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""A non-blocking TCP connection factory.
"""
import functools
import socket
import numbers
import datetime
import ssl
from tornado.concurrent import Future, future_add_done_callback
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream
from tornado import gen
from tornado.netutil import Resolver
from tornado.platform.auto import set_close_exec
from tornado.gen import TimeoutError
import typing
from typing import Any, Union, Dict, Tuple, List, Callable, Iterator
if typing.TYPE_CHECKING:
from typing import Optional, Set # noqa: F401
_INITIAL_CONNECT_TIMEOUT = 0.3
class _Connector(object):
"""A stateless implementation of the "Happy Eyeballs" algorithm.
"Happy Eyeballs" is documented in RFC6555 as the recommended practice
for when both IPv4 and IPv6 addresses are available.
In this implementation, we partition the addresses by family, and
make the first connection attempt to whichever address was
returned first by ``getaddrinfo``. If that connection fails or
times out, we begin a connection in parallel to the first address
of the other family. If there are additional failures we retry
with other addresses, keeping one connection attempt per family
in flight at a time.
http://tools.ietf.org/html/rfc6555
"""
def __init__(
self,
addrinfo: List[Tuple],
connect: Callable[
[socket.AddressFamily, Tuple], Tuple[IOStream, "Future[IOStream]"]
],
) -> None:
self.io_loop = IOLoop.current()
self.connect = connect
self.future = (
Future()
) # type: Future[Tuple[socket.AddressFamily, Any, IOStream]]
self.timeout = None # type: Optional[object]
self.connect_timeout = None # type: Optional[object]
self.last_error = None # type: Optional[Exception]
self.remaining = len(addrinfo)
self.primary_addrs, self.secondary_addrs = self.split(addrinfo)
self.streams = set() # type: Set[IOStream]
@staticmethod
def split(
addrinfo: List[Tuple]
) -> Tuple[
List[Tuple[socket.AddressFamily, Tuple]],
List[Tuple[socket.AddressFamily, Tuple]],
]:
"""Partition the ``addrinfo`` list by address family.
Returns two lists. The first list contains the first entry from
``addrinfo`` and all others with the same family, and the
second list contains all other addresses (normally one list will
be AF_INET and the other AF_INET6, although non-standard resolvers
may return additional families).
"""
primary = []
secondary = []
primary_af = addrinfo[0][0]
for af, addr in addrinfo:
if af == primary_af:
primary.append((af, addr))
else:
secondary.append((af, addr))
return primary, secondary
def start(
self,
timeout: float = _INITIAL_CONNECT_TIMEOUT,
connect_timeout: Union[float, datetime.timedelta] = None,
) -> "Future[Tuple[socket.AddressFamily, Any, IOStream]]":
self.try_connect(iter(self.primary_addrs))
self.set_timeout(timeout)
if connect_timeout is not None:
self.set_connect_timeout(connect_timeout)
return self.future
def try_connect(self, addrs: Iterator[Tuple[socket.AddressFamily, Tuple]]) -> None:
try:
af, addr = next(addrs)
except StopIteration:
# We've reached the end of our queue, but the other queue
# might still be working. Send a final error on the future
# only when both queues are finished.
if self.remaining == 0 and not self.future.done():
self.future.set_exception(
self.last_error or IOError("connection failed")
)
return
stream, future = self.connect(af, addr)
self.streams.add(stream)
future_add_done_callback(
future, functools.partial(self.on_connect_done, addrs, af, addr)
)
def on_connect_done(
self,
addrs: Iterator[Tuple[socket.AddressFamily, Tuple]],
af: socket.AddressFamily,
addr: Tuple,
future: "Future[IOStream]",
) -> None:
self.remaining -= 1
try:
stream = future.result()
except Exception as e:
if self.future.done():
return
# Error: try again (but remember what happened so we have an
# error to raise in the end)
self.last_error = e
self.try_connect(addrs)
if self.timeout is not None:
# If the first attempt failed, don't wait for the
# timeout to try an address from the secondary queue.
self.io_loop.remove_timeout(self.timeout)
self.on_timeout()
return
self.clear_timeouts()
if self.future.done():
# This is a late arrival; just drop it.
stream.close()
else:
self.streams.discard(stream)
self.future.set_result((af, addr, stream))
self.close_streams()
def set_timeout(self, timeout: float) -> None:
self.timeout = self.io_loop.add_timeout(
self.io_loop.time() + timeout, self.on_timeout
)
def on_timeout(self) -> None:
self.timeout = None
if not self.future.done():
self.try_connect(iter(self.secondary_addrs))
def clear_timeout(self) -> None:
if self.timeout is not None:
self.io_loop.remove_timeout(self.timeout)
def set_connect_timeout(
self, connect_timeout: Union[float, datetime.timedelta]
) -> None:
self.connect_timeout = self.io_loop.add_timeout(
connect_timeout, self.on_connect_timeout
)
def on_connect_timeout(self) -> None:
if not self.future.done():
self.future.set_exception(TimeoutError())
self.close_streams()
def clear_timeouts(self) -> None:
if self.timeout is not None:
self.io_loop.remove_timeout(self.timeout)
if self.connect_timeout is not None:
self.io_loop.remove_timeout(self.connect_timeout)
def close_streams(self) -> None:
for stream in self.streams:
stream.close()
class TCPClient(object):
"""A non-blocking TCP connection factory.
.. versionchanged:: 5.0
The ``io_loop`` argument (deprecated since version 4.1) has been removed.
"""
def __init__(self, resolver: Resolver = None) -> None:
if resolver is not None:
self.resolver = resolver
self._own_resolver = False
else:
self.resolver = Resolver()
self._own_resolver = True
def close(self) -> None:
if self._own_resolver:
self.resolver.close()
async def connect(
self,
host: str,
port: int,
af: socket.AddressFamily = socket.AF_UNSPEC,
ssl_options: Union[Dict[str, Any], ssl.SSLContext] = None,
max_buffer_size: int = None,
source_ip: str = None,
source_port: int = None,
timeout: Union[float, datetime.timedelta] = None,
) -> IOStream:
"""Connect to the given host and port.
Asynchronously returns an `.IOStream` (or `.SSLIOStream` if
``ssl_options`` is not None).
Using the ``source_ip`` kwarg, one can specify the source
IP address to use when establishing the connection.
In case the user needs to resolve and
use a specific interface, it has to be handled outside
of Tornado as this depends very much on the platform.
Raises `TimeoutError` if the input future does not complete before
``timeout``, which may be specified in any form allowed by
`.IOLoop.add_timeout` (i.e. a `datetime.timedelta` or an absolute time
relative to `.IOLoop.time`)
Similarly, when the user requires a certain source port, it can
be specified using the ``source_port`` arg.
.. versionchanged:: 4.5
Added the ``source_ip`` and ``source_port`` arguments.
.. versionchanged:: 5.0
Added the ``timeout`` argument.
"""
if timeout is not None:
if isinstance(timeout, numbers.Real):
timeout = IOLoop.current().time() + timeout
elif isinstance(timeout, datetime.timedelta):
timeout = IOLoop.current().time() + timeout.total_seconds()
else:
raise TypeError("Unsupported timeout %r" % timeout)
if timeout is not None:
addrinfo = await gen.with_timeout(
timeout, self.resolver.resolve(host, port, af)
)
else:
addrinfo = await self.resolver.resolve(host, port, af)
connector = _Connector(
addrinfo,
functools.partial(
self._create_stream,
max_buffer_size,
source_ip=source_ip,
source_port=source_port,
),
)
af, addr, stream = await connector.start(connect_timeout=timeout)
# TODO: For better performance we could cache the (af, addr)
# information here and re-use it on subsequent connections to
# the same host. (http://tools.ietf.org/html/rfc6555#section-4.2)
if ssl_options is not None:
if timeout is not None:
stream = await gen.with_timeout(
timeout,
stream.start_tls(
False, ssl_options=ssl_options, server_hostname=host
),
)
else:
stream = await stream.start_tls(
False, ssl_options=ssl_options, server_hostname=host
)
return stream
def _create_stream(
self,
max_buffer_size: int,
af: socket.AddressFamily,
addr: Tuple,
source_ip: str = None,
source_port: int = None,
) -> Tuple[IOStream, "Future[IOStream]"]:
# Always connect in plaintext; we'll convert to ssl if necessary
# after one connection has completed.
source_port_bind = source_port if isinstance(source_port, int) else 0
source_ip_bind = source_ip
if source_port_bind and not source_ip:
# User required a specific port, but did not specify
# a certain source IP, will bind to the default loopback.
source_ip_bind = "::1" if af == socket.AF_INET6 else "127.0.0.1"
# Trying to use the same address family as the requested af socket:
# - 127.0.0.1 for IPv4
# - ::1 for IPv6
socket_obj = socket.socket(af)
set_close_exec(socket_obj.fileno())
if source_port_bind or source_ip_bind:
# If the user requires binding also to a specific IP/port.
try:
socket_obj.bind((source_ip_bind, source_port_bind))
except socket.error:
socket_obj.close()
# Fail loudly if unable to use the IP/port.
raise
try:
stream = IOStream(socket_obj, max_buffer_size=max_buffer_size)
except socket.error as e:
fu = Future() # type: Future[IOStream]
fu.set_exception(e)
return stream, fu
else:
return stream, stream.connect(addr)

View File

@@ -0,0 +1,330 @@
#
# Copyright 2011 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""A non-blocking, single-threaded TCP server."""
import errno
import os
import socket
import ssl
from tornado import gen
from tornado.log import app_log
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream, SSLIOStream
from tornado.netutil import bind_sockets, add_accept_handler, ssl_wrap_socket
from tornado import process
from tornado.util import errno_from_exception
import typing
from typing import Union, Dict, Any, Iterable, Optional, Awaitable
if typing.TYPE_CHECKING:
from typing import Callable, List # noqa: F401
class TCPServer(object):
r"""A non-blocking, single-threaded TCP server.
To use `TCPServer`, define a subclass which overrides the `handle_stream`
method. For example, a simple echo server could be defined like this::
from tornado.tcpserver import TCPServer
from tornado.iostream import StreamClosedError
from tornado import gen
class EchoServer(TCPServer):
async def handle_stream(self, stream, address):
while True:
try:
data = await stream.read_until(b"\n")
await stream.write(data)
except StreamClosedError:
break
To make this server serve SSL traffic, send the ``ssl_options`` keyword
argument with an `ssl.SSLContext` object. For compatibility with older
versions of Python ``ssl_options`` may also be a dictionary of keyword
arguments for the `ssl.wrap_socket` method.::
ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_ctx.load_cert_chain(os.path.join(data_dir, "mydomain.crt"),
os.path.join(data_dir, "mydomain.key"))
TCPServer(ssl_options=ssl_ctx)
`TCPServer` initialization follows one of three patterns:
1. `listen`: simple single-process::
server = TCPServer()
server.listen(8888)
IOLoop.current().start()
2. `bind`/`start`: simple multi-process::
server = TCPServer()
server.bind(8888)
server.start(0) # Forks multiple sub-processes
IOLoop.current().start()
When using this interface, an `.IOLoop` must *not* be passed
to the `TCPServer` constructor. `start` will always start
the server on the default singleton `.IOLoop`.
3. `add_sockets`: advanced multi-process::
sockets = bind_sockets(8888)
tornado.process.fork_processes(0)
server = TCPServer()
server.add_sockets(sockets)
IOLoop.current().start()
The `add_sockets` interface is more complicated, but it can be
used with `tornado.process.fork_processes` to give you more
flexibility in when the fork happens. `add_sockets` can
also be used in single-process servers if you want to create
your listening sockets in some way other than
`~tornado.netutil.bind_sockets`.
.. versionadded:: 3.1
The ``max_buffer_size`` argument.
.. versionchanged:: 5.0
The ``io_loop`` argument has been removed.
"""
def __init__(
self,
ssl_options: Union[Dict[str, Any], ssl.SSLContext] = None,
max_buffer_size: int = None,
read_chunk_size: int = None,
) -> None:
self.ssl_options = ssl_options
self._sockets = {} # type: Dict[int, socket.socket]
self._handlers = {} # type: Dict[int, Callable[[], None]]
self._pending_sockets = [] # type: List[socket.socket]
self._started = False
self._stopped = False
self.max_buffer_size = max_buffer_size
self.read_chunk_size = read_chunk_size
# Verify the SSL options. Otherwise we don't get errors until clients
# connect. This doesn't verify that the keys are legitimate, but
# the SSL module doesn't do that until there is a connected socket
# which seems like too much work
if self.ssl_options is not None and isinstance(self.ssl_options, dict):
# Only certfile is required: it can contain both keys
if "certfile" not in self.ssl_options:
raise KeyError('missing key "certfile" in ssl_options')
if not os.path.exists(self.ssl_options["certfile"]):
raise ValueError(
'certfile "%s" does not exist' % self.ssl_options["certfile"]
)
if "keyfile" in self.ssl_options and not os.path.exists(
self.ssl_options["keyfile"]
):
raise ValueError(
'keyfile "%s" does not exist' % self.ssl_options["keyfile"]
)
def listen(self, port: int, address: str = "") -> None:
"""Starts accepting connections on the given port.
This method may be called more than once to listen on multiple ports.
`listen` takes effect immediately; it is not necessary to call
`TCPServer.start` afterwards. It is, however, necessary to start
the `.IOLoop`.
"""
sockets = bind_sockets(port, address=address)
self.add_sockets(sockets)
def add_sockets(self, sockets: Iterable[socket.socket]) -> None:
"""Makes this server start accepting connections on the given sockets.
The ``sockets`` parameter is a list of socket objects such as
those returned by `~tornado.netutil.bind_sockets`.
`add_sockets` is typically used in combination with that
method and `tornado.process.fork_processes` to provide greater
control over the initialization of a multi-process server.
"""
for sock in sockets:
self._sockets[sock.fileno()] = sock
self._handlers[sock.fileno()] = add_accept_handler(
sock, self._handle_connection
)
def add_socket(self, socket: socket.socket) -> None:
"""Singular version of `add_sockets`. Takes a single socket object."""
self.add_sockets([socket])
def bind(
self,
port: int,
address: str = None,
family: socket.AddressFamily = socket.AF_UNSPEC,
backlog: int = 128,
reuse_port: bool = False,
) -> None:
"""Binds this server to the given port on the given address.
To start the server, call `start`. If you want to run this server
in a single process, you can call `listen` as a shortcut to the
sequence of `bind` and `start` calls.
Address may be either an IP address or hostname. If it's a hostname,
the server will listen on all IP addresses associated with the
name. Address may be an empty string or None to listen on all
available interfaces. Family may be set to either `socket.AF_INET`
or `socket.AF_INET6` to restrict to IPv4 or IPv6 addresses, otherwise
both will be used if available.
The ``backlog`` argument has the same meaning as for
`socket.listen <socket.socket.listen>`. The ``reuse_port`` argument
has the same meaning as for `.bind_sockets`.
This method may be called multiple times prior to `start` to listen
on multiple ports or interfaces.
.. versionchanged:: 4.4
Added the ``reuse_port`` argument.
"""
sockets = bind_sockets(
port, address=address, family=family, backlog=backlog, reuse_port=reuse_port
)
if self._started:
self.add_sockets(sockets)
else:
self._pending_sockets.extend(sockets)
def start(self, num_processes: Optional[int] = 1, max_restarts: int = None) -> None:
"""Starts this server in the `.IOLoop`.
By default, we run the server in this process and do not fork any
additional child process.
If num_processes is ``None`` or <= 0, we detect the number of cores
available on this machine and fork that number of child
processes. If num_processes is given and > 1, we fork that
specific number of sub-processes.
Since we use processes and not threads, there is no shared memory
between any server code.
Note that multiple processes are not compatible with the autoreload
module (or the ``autoreload=True`` option to `tornado.web.Application`
which defaults to True when ``debug=True``).
When using multiple processes, no IOLoops can be created or
referenced until after the call to ``TCPServer.start(n)``.
The ``max_restarts`` argument is passed to `.fork_processes`.
.. versionchanged:: 6.0
Added ``max_restarts`` argument.
"""
assert not self._started
self._started = True
if num_processes != 1:
process.fork_processes(num_processes, max_restarts)
sockets = self._pending_sockets
self._pending_sockets = []
self.add_sockets(sockets)
def stop(self) -> None:
"""Stops listening for new connections.
Requests currently in progress may still continue after the
server is stopped.
"""
if self._stopped:
return
self._stopped = True
for fd, sock in self._sockets.items():
assert sock.fileno() == fd
# Unregister socket from IOLoop
self._handlers.pop(fd)()
sock.close()
def handle_stream(
self, stream: IOStream, address: tuple
) -> Optional[Awaitable[None]]:
"""Override to handle a new `.IOStream` from an incoming connection.
This method may be a coroutine; if so any exceptions it raises
asynchronously will be logged. Accepting of incoming connections
will not be blocked by this coroutine.
If this `TCPServer` is configured for SSL, ``handle_stream``
may be called before the SSL handshake has completed. Use
`.SSLIOStream.wait_for_handshake` if you need to verify the client's
certificate or use NPN/ALPN.
.. versionchanged:: 4.2
Added the option for this method to be a coroutine.
"""
raise NotImplementedError()
def _handle_connection(self, connection: socket.socket, address: Any) -> None:
if self.ssl_options is not None:
assert ssl, "Python 2.6+ and OpenSSL required for SSL"
try:
connection = ssl_wrap_socket(
connection,
self.ssl_options,
server_side=True,
do_handshake_on_connect=False,
)
except ssl.SSLError as err:
if err.args[0] == ssl.SSL_ERROR_EOF:
return connection.close()
else:
raise
except socket.error as err:
# If the connection is closed immediately after it is created
# (as in a port scan), we can get one of several errors.
# wrap_socket makes an internal call to getpeername,
# which may return either EINVAL (Mac OS X) or ENOTCONN
# (Linux). If it returns ENOTCONN, this error is
# silently swallowed by the ssl module, so we need to
# catch another error later on (AttributeError in
# SSLIOStream._do_ssl_handshake).
# To test this behavior, try nmap with the -sT flag.
# https://github.com/tornadoweb/tornado/pull/750
if errno_from_exception(err) in (errno.ECONNABORTED, errno.EINVAL):
return connection.close()
else:
raise
try:
if self.ssl_options is not None:
stream = SSLIOStream(
connection,
max_buffer_size=self.max_buffer_size,
read_chunk_size=self.read_chunk_size,
) # type: IOStream
else:
stream = IOStream(
connection,
max_buffer_size=self.max_buffer_size,
read_chunk_size=self.read_chunk_size,
)
future = self.handle_stream(stream, address)
if future is not None:
IOLoop.current().add_future(
gen.convert_yielded(future), lambda f: f.result()
)
except Exception:
app_log.error("Error in connection callback", exc_info=True)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,12 @@
"""Shim to allow python -m tornado.test.
This only works in python 2.7+.
"""
from tornado.test.runtests import all, main
# tornado.testing.main autodiscovery relies on 'all' being present in
# the main module, so import it here even though it is not used directly.
# The following line prevents a pyflakes warning.
all = all
main()

View File

@@ -0,0 +1,190 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import asyncio
import unittest
from concurrent.futures import ThreadPoolExecutor
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.platform.asyncio import (
AsyncIOLoop,
to_asyncio_future,
AnyThreadEventLoopPolicy,
)
from tornado.testing import AsyncTestCase, gen_test
class AsyncIOLoopTest(AsyncTestCase):
def get_new_ioloop(self):
io_loop = AsyncIOLoop()
return io_loop
def test_asyncio_callback(self):
# Basic test that the asyncio loop is set up correctly.
asyncio.get_event_loop().call_soon(self.stop)
self.wait()
@gen_test
def test_asyncio_future(self):
# Test that we can yield an asyncio future from a tornado coroutine.
# Without 'yield from', we must wrap coroutines in ensure_future,
# which was introduced during Python 3.4, deprecating the prior "async".
if hasattr(asyncio, "ensure_future"):
ensure_future = asyncio.ensure_future
else:
# async is a reserved word in Python 3.7
ensure_future = getattr(asyncio, "async")
x = yield ensure_future(
asyncio.get_event_loop().run_in_executor(None, lambda: 42)
)
self.assertEqual(x, 42)
@gen_test
def test_asyncio_yield_from(self):
@gen.coroutine
def f():
event_loop = asyncio.get_event_loop()
x = yield from event_loop.run_in_executor(None, lambda: 42)
return x
result = yield f()
self.assertEqual(result, 42)
def test_asyncio_adapter(self):
# This test demonstrates that when using the asyncio coroutine
# runner (i.e. run_until_complete), the to_asyncio_future
# adapter is needed. No adapter is needed in the other direction,
# as demonstrated by other tests in the package.
@gen.coroutine
def tornado_coroutine():
yield gen.moment
raise gen.Return(42)
async def native_coroutine_without_adapter():
return await tornado_coroutine()
async def native_coroutine_with_adapter():
return await to_asyncio_future(tornado_coroutine())
# Use the adapter, but two degrees from the tornado coroutine.
async def native_coroutine_with_adapter2():
return await to_asyncio_future(native_coroutine_without_adapter())
# Tornado supports native coroutines both with and without adapters
self.assertEqual(self.io_loop.run_sync(native_coroutine_without_adapter), 42)
self.assertEqual(self.io_loop.run_sync(native_coroutine_with_adapter), 42)
self.assertEqual(self.io_loop.run_sync(native_coroutine_with_adapter2), 42)
# Asyncio only supports coroutines that yield asyncio-compatible
# Futures (which our Future is since 5.0).
self.assertEqual(
asyncio.get_event_loop().run_until_complete(
native_coroutine_without_adapter()
),
42,
)
self.assertEqual(
asyncio.get_event_loop().run_until_complete(
native_coroutine_with_adapter()
),
42,
)
self.assertEqual(
asyncio.get_event_loop().run_until_complete(
native_coroutine_with_adapter2()
),
42,
)
class LeakTest(unittest.TestCase):
def setUp(self):
# Trigger a cleanup of the mapping so we start with a clean slate.
AsyncIOLoop().close()
# If we don't clean up after ourselves other tests may fail on
# py34.
self.orig_policy = asyncio.get_event_loop_policy()
asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
def tearDown(self):
asyncio.get_event_loop().close()
asyncio.set_event_loop_policy(self.orig_policy)
def test_ioloop_close_leak(self):
orig_count = len(IOLoop._ioloop_for_asyncio)
for i in range(10):
# Create and close an AsyncIOLoop using Tornado interfaces.
loop = AsyncIOLoop()
loop.close()
new_count = len(IOLoop._ioloop_for_asyncio) - orig_count
self.assertEqual(new_count, 0)
def test_asyncio_close_leak(self):
orig_count = len(IOLoop._ioloop_for_asyncio)
for i in range(10):
# Create and close an AsyncIOMainLoop using asyncio interfaces.
loop = asyncio.new_event_loop()
loop.call_soon(IOLoop.current)
loop.call_soon(loop.stop)
loop.run_forever()
loop.close()
new_count = len(IOLoop._ioloop_for_asyncio) - orig_count
# Because the cleanup is run on new loop creation, we have one
# dangling entry in the map (but only one).
self.assertEqual(new_count, 1)
class AnyThreadEventLoopPolicyTest(unittest.TestCase):
def setUp(self):
self.orig_policy = asyncio.get_event_loop_policy()
self.executor = ThreadPoolExecutor(1)
def tearDown(self):
asyncio.set_event_loop_policy(self.orig_policy)
self.executor.shutdown()
def get_event_loop_on_thread(self):
def get_and_close_event_loop():
"""Get the event loop. Close it if one is returned.
Returns the (closed) event loop. This is a silly thing
to do and leaves the thread in a broken state, but it's
enough for this test. Closing the loop avoids resource
leak warnings.
"""
loop = asyncio.get_event_loop()
loop.close()
return loop
future = self.executor.submit(get_and_close_event_loop)
return future.result()
def run_policy_test(self, accessor, expected_type):
# With the default policy, non-main threads don't get an event
# loop.
self.assertRaises(
(RuntimeError, AssertionError), self.executor.submit(accessor).result
)
# Set the policy and we can get a loop.
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
self.assertIsInstance(self.executor.submit(accessor).result(), expected_type)
# Clean up to silence leak warnings. Always use asyncio since
# IOLoop doesn't (currently) close the underlying loop.
self.executor.submit(lambda: asyncio.get_event_loop().close()).result()
def test_asyncio_accessor(self):
self.run_policy_test(asyncio.get_event_loop, asyncio.AbstractEventLoop)
def test_tornado_accessor(self):
self.run_policy_test(IOLoop.current, IOLoop)

View File

@@ -0,0 +1,609 @@
# These tests do not currently do much to verify the correct implementation
# of the openid/oauth protocols, they just exercise the major code paths
# and ensure that it doesn't blow up (e.g. with unicode/bytes issues in
# python 3)
import unittest
from tornado.auth import (
OpenIdMixin,
OAuthMixin,
OAuth2Mixin,
GoogleOAuth2Mixin,
FacebookGraphMixin,
TwitterMixin,
)
from tornado.escape import json_decode
from tornado import gen
from tornado.httpclient import HTTPClientError
from tornado.httputil import url_concat
from tornado.log import app_log
from tornado.testing import AsyncHTTPTestCase, ExpectLog
from tornado.web import RequestHandler, Application, HTTPError
try:
from unittest import mock
except ImportError:
mock = None # type: ignore
class OpenIdClientLoginHandler(RequestHandler, OpenIdMixin):
def initialize(self, test):
self._OPENID_ENDPOINT = test.get_url("/openid/server/authenticate")
@gen.coroutine
def get(self):
if self.get_argument("openid.mode", None):
user = yield self.get_authenticated_user(
http_client=self.settings["http_client"]
)
if user is None:
raise Exception("user is None")
self.finish(user)
return
res = self.authenticate_redirect()
assert res is None
class OpenIdServerAuthenticateHandler(RequestHandler):
def post(self):
if self.get_argument("openid.mode") != "check_authentication":
raise Exception("incorrect openid.mode %r")
self.write("is_valid:true")
class OAuth1ClientLoginHandler(RequestHandler, OAuthMixin):
def initialize(self, test, version):
self._OAUTH_VERSION = version
self._OAUTH_REQUEST_TOKEN_URL = test.get_url("/oauth1/server/request_token")
self._OAUTH_AUTHORIZE_URL = test.get_url("/oauth1/server/authorize")
self._OAUTH_ACCESS_TOKEN_URL = test.get_url("/oauth1/server/access_token")
def _oauth_consumer_token(self):
return dict(key="asdf", secret="qwer")
@gen.coroutine
def get(self):
if self.get_argument("oauth_token", None):
user = yield self.get_authenticated_user(
http_client=self.settings["http_client"]
)
if user is None:
raise Exception("user is None")
self.finish(user)
return
yield self.authorize_redirect(http_client=self.settings["http_client"])
@gen.coroutine
def _oauth_get_user_future(self, access_token):
if self.get_argument("fail_in_get_user", None):
raise Exception("failing in get_user")
if access_token != dict(key="uiop", secret="5678"):
raise Exception("incorrect access token %r" % access_token)
return dict(email="foo@example.com")
class OAuth1ClientLoginCoroutineHandler(OAuth1ClientLoginHandler):
"""Replaces OAuth1ClientLoginCoroutineHandler's get() with a coroutine."""
@gen.coroutine
def get(self):
if self.get_argument("oauth_token", None):
# Ensure that any exceptions are set on the returned Future,
# not simply thrown into the surrounding StackContext.
try:
yield self.get_authenticated_user()
except Exception as e:
self.set_status(503)
self.write("got exception: %s" % e)
else:
yield self.authorize_redirect()
class OAuth1ClientRequestParametersHandler(RequestHandler, OAuthMixin):
def initialize(self, version):
self._OAUTH_VERSION = version
def _oauth_consumer_token(self):
return dict(key="asdf", secret="qwer")
def get(self):
params = self._oauth_request_parameters(
"http://www.example.com/api/asdf",
dict(key="uiop", secret="5678"),
parameters=dict(foo="bar"),
)
self.write(params)
class OAuth1ServerRequestTokenHandler(RequestHandler):
def get(self):
self.write("oauth_token=zxcv&oauth_token_secret=1234")
class OAuth1ServerAccessTokenHandler(RequestHandler):
def get(self):
self.write("oauth_token=uiop&oauth_token_secret=5678")
class OAuth2ClientLoginHandler(RequestHandler, OAuth2Mixin):
def initialize(self, test):
self._OAUTH_AUTHORIZE_URL = test.get_url("/oauth2/server/authorize")
def get(self):
res = self.authorize_redirect()
assert res is None
class FacebookClientLoginHandler(RequestHandler, FacebookGraphMixin):
def initialize(self, test):
self._OAUTH_AUTHORIZE_URL = test.get_url("/facebook/server/authorize")
self._OAUTH_ACCESS_TOKEN_URL = test.get_url("/facebook/server/access_token")
self._FACEBOOK_BASE_URL = test.get_url("/facebook/server")
@gen.coroutine
def get(self):
if self.get_argument("code", None):
user = yield self.get_authenticated_user(
redirect_uri=self.request.full_url(),
client_id=self.settings["facebook_api_key"],
client_secret=self.settings["facebook_secret"],
code=self.get_argument("code"),
)
self.write(user)
else:
yield self.authorize_redirect(
redirect_uri=self.request.full_url(),
client_id=self.settings["facebook_api_key"],
extra_params={"scope": "read_stream,offline_access"},
)
class FacebookServerAccessTokenHandler(RequestHandler):
def get(self):
self.write(dict(access_token="asdf", expires_in=3600))
class FacebookServerMeHandler(RequestHandler):
def get(self):
self.write("{}")
class TwitterClientHandler(RequestHandler, TwitterMixin):
def initialize(self, test):
self._OAUTH_REQUEST_TOKEN_URL = test.get_url("/oauth1/server/request_token")
self._OAUTH_ACCESS_TOKEN_URL = test.get_url("/twitter/server/access_token")
self._OAUTH_AUTHORIZE_URL = test.get_url("/oauth1/server/authorize")
self._OAUTH_AUTHENTICATE_URL = test.get_url("/twitter/server/authenticate")
self._TWITTER_BASE_URL = test.get_url("/twitter/api")
def get_auth_http_client(self):
return self.settings["http_client"]
class TwitterClientLoginHandler(TwitterClientHandler):
@gen.coroutine
def get(self):
if self.get_argument("oauth_token", None):
user = yield self.get_authenticated_user()
if user is None:
raise Exception("user is None")
self.finish(user)
return
yield self.authorize_redirect()
class TwitterClientAuthenticateHandler(TwitterClientHandler):
# Like TwitterClientLoginHandler, but uses authenticate_redirect
# instead of authorize_redirect.
@gen.coroutine
def get(self):
if self.get_argument("oauth_token", None):
user = yield self.get_authenticated_user()
if user is None:
raise Exception("user is None")
self.finish(user)
return
yield self.authenticate_redirect()
class TwitterClientLoginGenCoroutineHandler(TwitterClientHandler):
@gen.coroutine
def get(self):
if self.get_argument("oauth_token", None):
user = yield self.get_authenticated_user()
self.finish(user)
else:
# New style: with @gen.coroutine the result must be yielded
# or else the request will be auto-finished too soon.
yield self.authorize_redirect()
class TwitterClientShowUserHandler(TwitterClientHandler):
@gen.coroutine
def get(self):
# TODO: would be nice to go through the login flow instead of
# cheating with a hard-coded access token.
try:
response = yield self.twitter_request(
"/users/show/%s" % self.get_argument("name"),
access_token=dict(key="hjkl", secret="vbnm"),
)
except HTTPClientError:
# TODO(bdarnell): Should we catch HTTP errors and
# transform some of them (like 403s) into AuthError?
self.set_status(500)
self.finish("error from twitter request")
else:
self.finish(response)
class TwitterServerAccessTokenHandler(RequestHandler):
def get(self):
self.write("oauth_token=hjkl&oauth_token_secret=vbnm&screen_name=foo")
class TwitterServerShowUserHandler(RequestHandler):
def get(self, screen_name):
if screen_name == "error":
raise HTTPError(500)
assert "oauth_nonce" in self.request.arguments
assert "oauth_timestamp" in self.request.arguments
assert "oauth_signature" in self.request.arguments
assert self.get_argument("oauth_consumer_key") == "test_twitter_consumer_key"
assert self.get_argument("oauth_signature_method") == "HMAC-SHA1"
assert self.get_argument("oauth_version") == "1.0"
assert self.get_argument("oauth_token") == "hjkl"
self.write(dict(screen_name=screen_name, name=screen_name.capitalize()))
class TwitterServerVerifyCredentialsHandler(RequestHandler):
def get(self):
assert "oauth_nonce" in self.request.arguments
assert "oauth_timestamp" in self.request.arguments
assert "oauth_signature" in self.request.arguments
assert self.get_argument("oauth_consumer_key") == "test_twitter_consumer_key"
assert self.get_argument("oauth_signature_method") == "HMAC-SHA1"
assert self.get_argument("oauth_version") == "1.0"
assert self.get_argument("oauth_token") == "hjkl"
self.write(dict(screen_name="foo", name="Foo"))
class AuthTest(AsyncHTTPTestCase):
def get_app(self):
return Application(
[
# test endpoints
("/openid/client/login", OpenIdClientLoginHandler, dict(test=self)),
(
"/oauth10/client/login",
OAuth1ClientLoginHandler,
dict(test=self, version="1.0"),
),
(
"/oauth10/client/request_params",
OAuth1ClientRequestParametersHandler,
dict(version="1.0"),
),
(
"/oauth10a/client/login",
OAuth1ClientLoginHandler,
dict(test=self, version="1.0a"),
),
(
"/oauth10a/client/login_coroutine",
OAuth1ClientLoginCoroutineHandler,
dict(test=self, version="1.0a"),
),
(
"/oauth10a/client/request_params",
OAuth1ClientRequestParametersHandler,
dict(version="1.0a"),
),
("/oauth2/client/login", OAuth2ClientLoginHandler, dict(test=self)),
("/facebook/client/login", FacebookClientLoginHandler, dict(test=self)),
("/twitter/client/login", TwitterClientLoginHandler, dict(test=self)),
(
"/twitter/client/authenticate",
TwitterClientAuthenticateHandler,
dict(test=self),
),
(
"/twitter/client/login_gen_coroutine",
TwitterClientLoginGenCoroutineHandler,
dict(test=self),
),
(
"/twitter/client/show_user",
TwitterClientShowUserHandler,
dict(test=self),
),
# simulated servers
("/openid/server/authenticate", OpenIdServerAuthenticateHandler),
("/oauth1/server/request_token", OAuth1ServerRequestTokenHandler),
("/oauth1/server/access_token", OAuth1ServerAccessTokenHandler),
("/facebook/server/access_token", FacebookServerAccessTokenHandler),
("/facebook/server/me", FacebookServerMeHandler),
("/twitter/server/access_token", TwitterServerAccessTokenHandler),
(r"/twitter/api/users/show/(.*)\.json", TwitterServerShowUserHandler),
(
r"/twitter/api/account/verify_credentials\.json",
TwitterServerVerifyCredentialsHandler,
),
],
http_client=self.http_client,
twitter_consumer_key="test_twitter_consumer_key",
twitter_consumer_secret="test_twitter_consumer_secret",
facebook_api_key="test_facebook_api_key",
facebook_secret="test_facebook_secret",
)
def test_openid_redirect(self):
response = self.fetch("/openid/client/login", follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue("/openid/server/authenticate?" in response.headers["Location"])
def test_openid_get_user(self):
response = self.fetch(
"/openid/client/login?openid.mode=blah"
"&openid.ns.ax=http://openid.net/srv/ax/1.0"
"&openid.ax.type.email=http://axschema.org/contact/email"
"&openid.ax.value.email=foo@example.com"
)
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed["email"], "foo@example.com")
def test_oauth10_redirect(self):
response = self.fetch("/oauth10/client/login", follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(
response.headers["Location"].endswith(
"/oauth1/server/authorize?oauth_token=zxcv"
)
)
# the cookie is base64('zxcv')|base64('1234')
self.assertTrue(
'_oauth_request_token="enhjdg==|MTIzNA=="'
in response.headers["Set-Cookie"],
response.headers["Set-Cookie"],
)
def test_oauth10_get_user(self):
response = self.fetch(
"/oauth10/client/login?oauth_token=zxcv",
headers={"Cookie": "_oauth_request_token=enhjdg==|MTIzNA=="},
)
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed["email"], "foo@example.com")
self.assertEqual(parsed["access_token"], dict(key="uiop", secret="5678"))
def test_oauth10_request_parameters(self):
response = self.fetch("/oauth10/client/request_params")
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed["oauth_consumer_key"], "asdf")
self.assertEqual(parsed["oauth_token"], "uiop")
self.assertTrue("oauth_nonce" in parsed)
self.assertTrue("oauth_signature" in parsed)
def test_oauth10a_redirect(self):
response = self.fetch("/oauth10a/client/login", follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(
response.headers["Location"].endswith(
"/oauth1/server/authorize?oauth_token=zxcv"
)
)
# the cookie is base64('zxcv')|base64('1234')
self.assertTrue(
'_oauth_request_token="enhjdg==|MTIzNA=="'
in response.headers["Set-Cookie"],
response.headers["Set-Cookie"],
)
@unittest.skipIf(mock is None, "mock package not present")
def test_oauth10a_redirect_error(self):
with mock.patch.object(OAuth1ServerRequestTokenHandler, "get") as get:
get.side_effect = Exception("boom")
with ExpectLog(app_log, "Uncaught exception"):
response = self.fetch("/oauth10a/client/login", follow_redirects=False)
self.assertEqual(response.code, 500)
def test_oauth10a_get_user(self):
response = self.fetch(
"/oauth10a/client/login?oauth_token=zxcv",
headers={"Cookie": "_oauth_request_token=enhjdg==|MTIzNA=="},
)
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed["email"], "foo@example.com")
self.assertEqual(parsed["access_token"], dict(key="uiop", secret="5678"))
def test_oauth10a_request_parameters(self):
response = self.fetch("/oauth10a/client/request_params")
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(parsed["oauth_consumer_key"], "asdf")
self.assertEqual(parsed["oauth_token"], "uiop")
self.assertTrue("oauth_nonce" in parsed)
self.assertTrue("oauth_signature" in parsed)
def test_oauth10a_get_user_coroutine_exception(self):
response = self.fetch(
"/oauth10a/client/login_coroutine?oauth_token=zxcv&fail_in_get_user=true",
headers={"Cookie": "_oauth_request_token=enhjdg==|MTIzNA=="},
)
self.assertEqual(response.code, 503)
def test_oauth2_redirect(self):
response = self.fetch("/oauth2/client/login", follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue("/oauth2/server/authorize?" in response.headers["Location"])
def test_facebook_login(self):
response = self.fetch("/facebook/client/login", follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue("/facebook/server/authorize?" in response.headers["Location"])
response = self.fetch(
"/facebook/client/login?code=1234", follow_redirects=False
)
self.assertEqual(response.code, 200)
user = json_decode(response.body)
self.assertEqual(user["access_token"], "asdf")
self.assertEqual(user["session_expires"], "3600")
def base_twitter_redirect(self, url):
# Same as test_oauth10a_redirect
response = self.fetch(url, follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(
response.headers["Location"].endswith(
"/oauth1/server/authorize?oauth_token=zxcv"
)
)
# the cookie is base64('zxcv')|base64('1234')
self.assertTrue(
'_oauth_request_token="enhjdg==|MTIzNA=="'
in response.headers["Set-Cookie"],
response.headers["Set-Cookie"],
)
def test_twitter_redirect(self):
self.base_twitter_redirect("/twitter/client/login")
def test_twitter_redirect_gen_coroutine(self):
self.base_twitter_redirect("/twitter/client/login_gen_coroutine")
def test_twitter_authenticate_redirect(self):
response = self.fetch("/twitter/client/authenticate", follow_redirects=False)
self.assertEqual(response.code, 302)
self.assertTrue(
response.headers["Location"].endswith(
"/twitter/server/authenticate?oauth_token=zxcv"
),
response.headers["Location"],
)
# the cookie is base64('zxcv')|base64('1234')
self.assertTrue(
'_oauth_request_token="enhjdg==|MTIzNA=="'
in response.headers["Set-Cookie"],
response.headers["Set-Cookie"],
)
def test_twitter_get_user(self):
response = self.fetch(
"/twitter/client/login?oauth_token=zxcv",
headers={"Cookie": "_oauth_request_token=enhjdg==|MTIzNA=="},
)
response.rethrow()
parsed = json_decode(response.body)
self.assertEqual(
parsed,
{
u"access_token": {
u"key": u"hjkl",
u"screen_name": u"foo",
u"secret": u"vbnm",
},
u"name": u"Foo",
u"screen_name": u"foo",
u"username": u"foo",
},
)
def test_twitter_show_user(self):
response = self.fetch("/twitter/client/show_user?name=somebody")
response.rethrow()
self.assertEqual(
json_decode(response.body), {"name": "Somebody", "screen_name": "somebody"}
)
def test_twitter_show_user_error(self):
response = self.fetch("/twitter/client/show_user?name=error")
self.assertEqual(response.code, 500)
self.assertEqual(response.body, b"error from twitter request")
class GoogleLoginHandler(RequestHandler, GoogleOAuth2Mixin):
def initialize(self, test):
self.test = test
self._OAUTH_REDIRECT_URI = test.get_url("/client/login")
self._OAUTH_AUTHORIZE_URL = test.get_url("/google/oauth2/authorize")
self._OAUTH_ACCESS_TOKEN_URL = test.get_url("/google/oauth2/token")
@gen.coroutine
def get(self):
code = self.get_argument("code", None)
if code is not None:
# retrieve authenticate google user
access = yield self.get_authenticated_user(self._OAUTH_REDIRECT_URI, code)
user = yield self.oauth2_request(
self.test.get_url("/google/oauth2/userinfo"),
access_token=access["access_token"],
)
# return the user and access token as json
user["access_token"] = access["access_token"]
self.write(user)
else:
yield self.authorize_redirect(
redirect_uri=self._OAUTH_REDIRECT_URI,
client_id=self.settings["google_oauth"]["key"],
client_secret=self.settings["google_oauth"]["secret"],
scope=["profile", "email"],
response_type="code",
extra_params={"prompt": "select_account"},
)
class GoogleOAuth2AuthorizeHandler(RequestHandler):
def get(self):
# issue a fake auth code and redirect to redirect_uri
code = "fake-authorization-code"
self.redirect(url_concat(self.get_argument("redirect_uri"), dict(code=code)))
class GoogleOAuth2TokenHandler(RequestHandler):
def post(self):
assert self.get_argument("code") == "fake-authorization-code"
# issue a fake token
self.finish(
{"access_token": "fake-access-token", "expires_in": "never-expires"}
)
class GoogleOAuth2UserinfoHandler(RequestHandler):
def get(self):
assert self.get_argument("access_token") == "fake-access-token"
# return a fake user
self.finish({"name": "Foo", "email": "foo@example.com"})
class GoogleOAuth2Test(AsyncHTTPTestCase):
def get_app(self):
return Application(
[
# test endpoints
("/client/login", GoogleLoginHandler, dict(test=self)),
# simulated google authorization server endpoints
("/google/oauth2/authorize", GoogleOAuth2AuthorizeHandler),
("/google/oauth2/token", GoogleOAuth2TokenHandler),
("/google/oauth2/userinfo", GoogleOAuth2UserinfoHandler),
],
google_oauth={
"key": "fake_google_client_id",
"secret": "fake_google_client_secret",
},
)
def test_google_login(self):
response = self.fetch("/client/login")
self.assertDictEqual(
{
u"name": u"Foo",
u"email": u"foo@example.com",
u"access_token": u"fake-access-token",
},
json_decode(response.body),
)

View File

@@ -0,0 +1,127 @@
import os
import shutil
import subprocess
from subprocess import Popen
import sys
from tempfile import mkdtemp
import time
import unittest
class AutoreloadTest(unittest.TestCase):
def setUp(self):
self.path = mkdtemp()
def tearDown(self):
try:
shutil.rmtree(self.path)
except OSError:
# Windows disallows deleting files that are in use by
# another process, and even though we've waited for our
# child process below, it appears that its lock on these
# files is not guaranteed to be released by this point.
# Sleep and try again (once).
time.sleep(1)
shutil.rmtree(self.path)
def test_reload_module(self):
main = """\
import os
import sys
from tornado import autoreload
# This import will fail if path is not set up correctly
import testapp
print('Starting')
if 'TESTAPP_STARTED' not in os.environ:
os.environ['TESTAPP_STARTED'] = '1'
sys.stdout.flush()
autoreload._reload()
"""
# Create temporary test application
os.mkdir(os.path.join(self.path, "testapp"))
open(os.path.join(self.path, "testapp/__init__.py"), "w").close()
with open(os.path.join(self.path, "testapp/__main__.py"), "w") as f:
f.write(main)
# Make sure the tornado module under test is available to the test
# application
pythonpath = os.getcwd()
if "PYTHONPATH" in os.environ:
pythonpath += os.pathsep + os.environ["PYTHONPATH"]
p = Popen(
[sys.executable, "-m", "testapp"],
stdout=subprocess.PIPE,
cwd=self.path,
env=dict(os.environ, PYTHONPATH=pythonpath),
universal_newlines=True,
)
out = p.communicate()[0]
self.assertEqual(out, "Starting\nStarting\n")
def test_reload_wrapper_preservation(self):
# This test verifies that when `python -m tornado.autoreload`
# is used on an application that also has an internal
# autoreload, the reload wrapper is preserved on restart.
main = """\
import os
import sys
# This import will fail if path is not set up correctly
import testapp
if 'tornado.autoreload' not in sys.modules:
raise Exception('started without autoreload wrapper')
import tornado.autoreload
print('Starting')
sys.stdout.flush()
if 'TESTAPP_STARTED' not in os.environ:
os.environ['TESTAPP_STARTED'] = '1'
# Simulate an internal autoreload (one not caused
# by the wrapper).
tornado.autoreload._reload()
else:
# Exit directly so autoreload doesn't catch it.
os._exit(0)
"""
# Create temporary test application
os.mkdir(os.path.join(self.path, "testapp"))
init_file = os.path.join(self.path, "testapp", "__init__.py")
open(init_file, "w").close()
main_file = os.path.join(self.path, "testapp", "__main__.py")
with open(main_file, "w") as f:
f.write(main)
# Make sure the tornado module under test is available to the test
# application
pythonpath = os.getcwd()
if "PYTHONPATH" in os.environ:
pythonpath += os.pathsep + os.environ["PYTHONPATH"]
autoreload_proc = Popen(
[sys.executable, "-m", "tornado.autoreload", "-m", "testapp"],
stdout=subprocess.PIPE,
cwd=self.path,
env=dict(os.environ, PYTHONPATH=pythonpath),
universal_newlines=True,
)
# This timeout needs to be fairly generous for pypy due to jit
# warmup costs.
for i in range(40):
if autoreload_proc.poll() is not None:
break
time.sleep(0.1)
else:
autoreload_proc.kill()
raise Exception("subprocess failed to terminate")
out = autoreload_proc.communicate()[0]
self.assertEqual(out, "Starting\n" * 2)

View File

@@ -0,0 +1,209 @@
#
# Copyright 2012 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from concurrent import futures
import logging
import re
import socket
import unittest
from tornado.concurrent import (
Future,
run_on_executor,
future_set_result_unless_cancelled,
)
from tornado.escape import utf8, to_unicode
from tornado import gen
from tornado.iostream import IOStream
from tornado.tcpserver import TCPServer
from tornado.testing import AsyncTestCase, bind_unused_port, gen_test
class MiscFutureTest(AsyncTestCase):
def test_future_set_result_unless_cancelled(self):
fut = Future() # type: Future[int]
future_set_result_unless_cancelled(fut, 42)
self.assertEqual(fut.result(), 42)
self.assertFalse(fut.cancelled())
fut = Future()
fut.cancel()
is_cancelled = fut.cancelled()
future_set_result_unless_cancelled(fut, 42)
self.assertEqual(fut.cancelled(), is_cancelled)
if not is_cancelled:
self.assertEqual(fut.result(), 42)
# The following series of classes demonstrate and test various styles
# of use, with and without generators and futures.
class CapServer(TCPServer):
@gen.coroutine
def handle_stream(self, stream, address):
data = yield stream.read_until(b"\n")
data = to_unicode(data)
if data == data.upper():
stream.write(b"error\talready capitalized\n")
else:
# data already has \n
stream.write(utf8("ok\t%s" % data.upper()))
stream.close()
class CapError(Exception):
pass
class BaseCapClient(object):
def __init__(self, port):
self.port = port
def process_response(self, data):
m = re.match("(.*)\t(.*)\n", to_unicode(data))
if m is None:
raise Exception("did not match")
status, message = m.groups()
if status == "ok":
return message
else:
raise CapError(message)
class GeneratorCapClient(BaseCapClient):
@gen.coroutine
def capitalize(self, request_data):
logging.debug("capitalize")
stream = IOStream(socket.socket())
logging.debug("connecting")
yield stream.connect(("127.0.0.1", self.port))
stream.write(utf8(request_data + "\n"))
logging.debug("reading")
data = yield stream.read_until(b"\n")
logging.debug("returning")
stream.close()
raise gen.Return(self.process_response(data))
class ClientTestMixin(object):
def setUp(self):
super(ClientTestMixin, self).setUp() # type: ignore
self.server = CapServer()
sock, port = bind_unused_port()
self.server.add_sockets([sock])
self.client = self.client_class(port=port)
def tearDown(self):
self.server.stop()
super(ClientTestMixin, self).tearDown() # type: ignore
def test_future(self):
future = self.client.capitalize("hello")
self.io_loop.add_future(future, self.stop)
self.wait()
self.assertEqual(future.result(), "HELLO")
def test_future_error(self):
future = self.client.capitalize("HELLO")
self.io_loop.add_future(future, self.stop)
self.wait()
self.assertRaisesRegexp(CapError, "already capitalized", future.result)
def test_generator(self):
@gen.coroutine
def f():
result = yield self.client.capitalize("hello")
self.assertEqual(result, "HELLO")
self.io_loop.run_sync(f)
def test_generator_error(self):
@gen.coroutine
def f():
with self.assertRaisesRegexp(CapError, "already capitalized"):
yield self.client.capitalize("HELLO")
self.io_loop.run_sync(f)
class GeneratorClientTest(ClientTestMixin, AsyncTestCase):
client_class = GeneratorCapClient
class RunOnExecutorTest(AsyncTestCase):
@gen_test
def test_no_calling(self):
class Object(object):
def __init__(self):
self.executor = futures.thread.ThreadPoolExecutor(1)
@run_on_executor
def f(self):
return 42
o = Object()
answer = yield o.f()
self.assertEqual(answer, 42)
@gen_test
def test_call_with_no_args(self):
class Object(object):
def __init__(self):
self.executor = futures.thread.ThreadPoolExecutor(1)
@run_on_executor()
def f(self):
return 42
o = Object()
answer = yield o.f()
self.assertEqual(answer, 42)
@gen_test
def test_call_with_executor(self):
class Object(object):
def __init__(self):
self.__executor = futures.thread.ThreadPoolExecutor(1)
@run_on_executor(executor="_Object__executor")
def f(self):
return 42
o = Object()
answer = yield o.f()
self.assertEqual(answer, 42)
@gen_test
def test_async_await(self):
class Object(object):
def __init__(self):
self.executor = futures.thread.ThreadPoolExecutor(1)
@run_on_executor()
def f(self):
return 42
o = Object()
async def f():
answer = await o.f()
return answer
result = yield f()
self.assertEqual(result, 42)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1 @@
"school","école"
1 school école

View File

@@ -0,0 +1,130 @@
# coding: utf-8
from hashlib import md5
import unittest
from tornado.escape import utf8
from tornado.testing import AsyncHTTPTestCase
from tornado.test import httpclient_test
from tornado.web import Application, RequestHandler
try:
import pycurl # type: ignore
except ImportError:
pycurl = None
if pycurl is not None:
from tornado.curl_httpclient import CurlAsyncHTTPClient
@unittest.skipIf(pycurl is None, "pycurl module not present")
class CurlHTTPClientCommonTestCase(httpclient_test.HTTPClientCommonTestCase):
def get_http_client(self):
client = CurlAsyncHTTPClient(defaults=dict(allow_ipv6=False))
# make sure AsyncHTTPClient magic doesn't give us the wrong class
self.assertTrue(isinstance(client, CurlAsyncHTTPClient))
return client
class DigestAuthHandler(RequestHandler):
def initialize(self, username, password):
self.username = username
self.password = password
def get(self):
realm = "test"
opaque = "asdf"
# Real implementations would use a random nonce.
nonce = "1234"
auth_header = self.request.headers.get("Authorization", None)
if auth_header is not None:
auth_mode, params = auth_header.split(" ", 1)
assert auth_mode == "Digest"
param_dict = {}
for pair in params.split(","):
k, v = pair.strip().split("=", 1)
if v[0] == '"' and v[-1] == '"':
v = v[1:-1]
param_dict[k] = v
assert param_dict["realm"] == realm
assert param_dict["opaque"] == opaque
assert param_dict["nonce"] == nonce
assert param_dict["username"] == self.username
assert param_dict["uri"] == self.request.path
h1 = md5(
utf8("%s:%s:%s" % (self.username, realm, self.password))
).hexdigest()
h2 = md5(
utf8("%s:%s" % (self.request.method, self.request.path))
).hexdigest()
digest = md5(utf8("%s:%s:%s" % (h1, nonce, h2))).hexdigest()
if digest == param_dict["response"]:
self.write("ok")
else:
self.write("fail")
else:
self.set_status(401)
self.set_header(
"WWW-Authenticate",
'Digest realm="%s", nonce="%s", opaque="%s"' % (realm, nonce, opaque),
)
class CustomReasonHandler(RequestHandler):
def get(self):
self.set_status(200, "Custom reason")
class CustomFailReasonHandler(RequestHandler):
def get(self):
self.set_status(400, "Custom reason")
@unittest.skipIf(pycurl is None, "pycurl module not present")
class CurlHTTPClientTestCase(AsyncHTTPTestCase):
def setUp(self):
super(CurlHTTPClientTestCase, self).setUp()
self.http_client = self.create_client()
def get_app(self):
return Application(
[
("/digest", DigestAuthHandler, {"username": "foo", "password": "bar"}),
(
"/digest_non_ascii",
DigestAuthHandler,
{"username": "foo", "password": "barユ£"},
),
("/custom_reason", CustomReasonHandler),
("/custom_fail_reason", CustomFailReasonHandler),
]
)
def create_client(self, **kwargs):
return CurlAsyncHTTPClient(
force_instance=True, defaults=dict(allow_ipv6=False), **kwargs
)
def test_digest_auth(self):
response = self.fetch(
"/digest", auth_mode="digest", auth_username="foo", auth_password="bar"
)
self.assertEqual(response.body, b"ok")
def test_custom_reason(self):
response = self.fetch("/custom_reason")
self.assertEqual(response.reason, "Custom reason")
def test_fail_custom_reason(self):
response = self.fetch("/custom_fail_reason")
self.assertEqual(str(response.error), "HTTP 400: Custom reason")
def test_digest_auth_non_ascii(self):
response = self.fetch(
"/digest_non_ascii",
auth_mode="digest",
auth_username="foo",
auth_password="barユ£",
)
self.assertEqual(response.body, b"ok")

View File

@@ -0,0 +1,322 @@
import unittest
import tornado.escape
from tornado.escape import (
utf8,
xhtml_escape,
xhtml_unescape,
url_escape,
url_unescape,
to_unicode,
json_decode,
json_encode,
squeeze,
recursive_unicode,
)
from tornado.util import unicode_type
from typing import List, Tuple, Union, Dict, Any # noqa: F401
linkify_tests = [
# (input, linkify_kwargs, expected_output)
(
"hello http://world.com/!",
{},
u'hello <a href="http://world.com/">http://world.com/</a>!',
),
(
"hello http://world.com/with?param=true&stuff=yes",
{},
u'hello <a href="http://world.com/with?param=true&amp;stuff=yes">http://world.com/with?param=true&amp;stuff=yes</a>', # noqa: E501
),
# an opened paren followed by many chars killed Gruber's regex
(
"http://url.com/w(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
{},
u'<a href="http://url.com/w">http://url.com/w</a>(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', # noqa: E501
),
# as did too many dots at the end
(
"http://url.com/withmany.......................................",
{},
u'<a href="http://url.com/withmany">http://url.com/withmany</a>.......................................', # noqa: E501
),
(
"http://url.com/withmany((((((((((((((((((((((((((((((((((a)",
{},
u'<a href="http://url.com/withmany">http://url.com/withmany</a>((((((((((((((((((((((((((((((((((a)', # noqa: E501
),
# some examples from http://daringfireball.net/2009/11/liberal_regex_for_matching_urls
# plus a fex extras (such as multiple parentheses).
(
"http://foo.com/blah_blah",
{},
u'<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>',
),
(
"http://foo.com/blah_blah/",
{},
u'<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>',
),
(
"(Something like http://foo.com/blah_blah)",
{},
u'(Something like <a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>)',
),
(
"http://foo.com/blah_blah_(wikipedia)",
{},
u'<a href="http://foo.com/blah_blah_(wikipedia)">http://foo.com/blah_blah_(wikipedia)</a>',
),
(
"http://foo.com/blah_(blah)_(wikipedia)_blah",
{},
u'<a href="http://foo.com/blah_(blah)_(wikipedia)_blah">http://foo.com/blah_(blah)_(wikipedia)_blah</a>', # noqa: E501
),
(
"(Something like http://foo.com/blah_blah_(wikipedia))",
{},
u'(Something like <a href="http://foo.com/blah_blah_(wikipedia)">http://foo.com/blah_blah_(wikipedia)</a>)', # noqa: E501
),
(
"http://foo.com/blah_blah.",
{},
u'<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>.',
),
(
"http://foo.com/blah_blah/.",
{},
u'<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>.',
),
(
"<http://foo.com/blah_blah>",
{},
u'&lt;<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>&gt;',
),
(
"<http://foo.com/blah_blah/>",
{},
u'&lt;<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>&gt;',
),
(
"http://foo.com/blah_blah,",
{},
u'<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>,',
),
(
"http://www.example.com/wpstyle/?p=364.",
{},
u'<a href="http://www.example.com/wpstyle/?p=364">http://www.example.com/wpstyle/?p=364</a>.', # noqa: E501
),
(
"rdar://1234",
{"permitted_protocols": ["http", "rdar"]},
u'<a href="rdar://1234">rdar://1234</a>',
),
(
"rdar:/1234",
{"permitted_protocols": ["rdar"]},
u'<a href="rdar:/1234">rdar:/1234</a>',
),
(
"http://userid:password@example.com:8080",
{},
u'<a href="http://userid:password@example.com:8080">http://userid:password@example.com:8080</a>', # noqa: E501
),
(
"http://userid@example.com",
{},
u'<a href="http://userid@example.com">http://userid@example.com</a>',
),
(
"http://userid@example.com:8080",
{},
u'<a href="http://userid@example.com:8080">http://userid@example.com:8080</a>',
),
(
"http://userid:password@example.com",
{},
u'<a href="http://userid:password@example.com">http://userid:password@example.com</a>',
),
(
"message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e",
{"permitted_protocols": ["http", "message"]},
u'<a href="message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e">'
u"message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e</a>",
),
(
u"http://\u27a1.ws/\u4a39",
{},
u'<a href="http://\u27a1.ws/\u4a39">http://\u27a1.ws/\u4a39</a>',
),
(
"<tag>http://example.com</tag>",
{},
u'&lt;tag&gt;<a href="http://example.com">http://example.com</a>&lt;/tag&gt;',
),
(
"Just a www.example.com link.",
{},
u'Just a <a href="http://www.example.com">www.example.com</a> link.',
),
(
"Just a www.example.com link.",
{"require_protocol": True},
u"Just a www.example.com link.",
),
(
"A http://reallylong.com/link/that/exceedsthelenglimit.html",
{"require_protocol": True, "shorten": True},
u'A <a href="http://reallylong.com/link/that/exceedsthelenglimit.html"'
u' title="http://reallylong.com/link/that/exceedsthelenglimit.html">http://reallylong.com/link...</a>', # noqa: E501
),
(
"A http://reallylongdomainnamethatwillbetoolong.com/hi!",
{"shorten": True},
u'A <a href="http://reallylongdomainnamethatwillbetoolong.com/hi"'
u' title="http://reallylongdomainnamethatwillbetoolong.com/hi">http://reallylongdomainnametha...</a>!', # noqa: E501
),
(
"A file:///passwords.txt and http://web.com link",
{},
u'A file:///passwords.txt and <a href="http://web.com">http://web.com</a> link',
),
(
"A file:///passwords.txt and http://web.com link",
{"permitted_protocols": ["file"]},
u'A <a href="file:///passwords.txt">file:///passwords.txt</a> and http://web.com link',
),
(
"www.external-link.com",
{"extra_params": 'rel="nofollow" class="external"'},
u'<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>', # noqa: E501
),
(
"www.external-link.com and www.internal-link.com/blogs extra",
{
"extra_params": lambda href: 'class="internal"'
if href.startswith("http://www.internal-link.com")
else 'rel="nofollow" class="external"'
},
u'<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>' # noqa: E501
u' and <a href="http://www.internal-link.com/blogs" class="internal">www.internal-link.com/blogs</a> extra', # noqa: E501
),
(
"www.external-link.com",
{"extra_params": lambda href: ' rel="nofollow" class="external" '},
u'<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>', # noqa: E501
),
] # type: List[Tuple[Union[str, bytes], Dict[str, Any], str]]
class EscapeTestCase(unittest.TestCase):
def test_linkify(self):
for text, kwargs, html in linkify_tests:
linked = tornado.escape.linkify(text, **kwargs)
self.assertEqual(linked, html)
def test_xhtml_escape(self):
tests = [
("<foo>", "&lt;foo&gt;"),
(u"<foo>", u"&lt;foo&gt;"),
(b"<foo>", b"&lt;foo&gt;"),
("<>&\"'", "&lt;&gt;&amp;&quot;&#39;"),
("&amp;", "&amp;amp;"),
(u"<\u00e9>", u"&lt;\u00e9&gt;"),
(b"<\xc3\xa9>", b"&lt;\xc3\xa9&gt;"),
] # type: List[Tuple[Union[str, bytes], Union[str, bytes]]]
for unescaped, escaped in tests:
self.assertEqual(utf8(xhtml_escape(unescaped)), utf8(escaped))
self.assertEqual(utf8(unescaped), utf8(xhtml_unescape(escaped)))
def test_xhtml_unescape_numeric(self):
tests = [
("foo&#32;bar", "foo bar"),
("foo&#x20;bar", "foo bar"),
("foo&#X20;bar", "foo bar"),
("foo&#xabc;bar", u"foo\u0abcbar"),
("foo&#xyz;bar", "foo&#xyz;bar"), # invalid encoding
("foo&#;bar", "foo&#;bar"), # invalid encoding
("foo&#x;bar", "foo&#x;bar"), # invalid encoding
]
for escaped, unescaped in tests:
self.assertEqual(unescaped, xhtml_unescape(escaped))
def test_url_escape_unicode(self):
tests = [
# byte strings are passed through as-is
(u"\u00e9".encode("utf8"), "%C3%A9"),
(u"\u00e9".encode("latin1"), "%E9"),
# unicode strings become utf8
(u"\u00e9", "%C3%A9"),
] # type: List[Tuple[Union[str, bytes], str]]
for unescaped, escaped in tests:
self.assertEqual(url_escape(unescaped), escaped)
def test_url_unescape_unicode(self):
tests = [
("%C3%A9", u"\u00e9", "utf8"),
("%C3%A9", u"\u00c3\u00a9", "latin1"),
("%C3%A9", utf8(u"\u00e9"), None),
]
for escaped, unescaped, encoding in tests:
# input strings to url_unescape should only contain ascii
# characters, but make sure the function accepts both byte
# and unicode strings.
self.assertEqual(url_unescape(to_unicode(escaped), encoding), unescaped)
self.assertEqual(url_unescape(utf8(escaped), encoding), unescaped)
def test_url_escape_quote_plus(self):
unescaped = "+ #%"
plus_escaped = "%2B+%23%25"
escaped = "%2B%20%23%25"
self.assertEqual(url_escape(unescaped), plus_escaped)
self.assertEqual(url_escape(unescaped, plus=False), escaped)
self.assertEqual(url_unescape(plus_escaped), unescaped)
self.assertEqual(url_unescape(escaped, plus=False), unescaped)
self.assertEqual(url_unescape(plus_escaped, encoding=None), utf8(unescaped))
self.assertEqual(
url_unescape(escaped, encoding=None, plus=False), utf8(unescaped)
)
def test_escape_return_types(self):
# On python2 the escape methods should generally return the same
# type as their argument
self.assertEqual(type(xhtml_escape("foo")), str)
self.assertEqual(type(xhtml_escape(u"foo")), unicode_type)
def test_json_decode(self):
# json_decode accepts both bytes and unicode, but strings it returns
# are always unicode.
self.assertEqual(json_decode(b'"foo"'), u"foo")
self.assertEqual(json_decode(u'"foo"'), u"foo")
# Non-ascii bytes are interpreted as utf8
self.assertEqual(json_decode(utf8(u'"\u00e9"')), u"\u00e9")
def test_json_encode(self):
# json deals with strings, not bytes. On python 2 byte strings will
# convert automatically if they are utf8; on python 3 byte strings
# are not allowed.
self.assertEqual(json_decode(json_encode(u"\u00e9")), u"\u00e9")
if bytes is str:
self.assertEqual(json_decode(json_encode(utf8(u"\u00e9"))), u"\u00e9")
self.assertRaises(UnicodeDecodeError, json_encode, b"\xe9")
def test_squeeze(self):
self.assertEqual(
squeeze(u"sequences of whitespace chars"),
u"sequences of whitespace chars",
)
def test_recursive_unicode(self):
tests = {
"dict": {b"foo": b"bar"},
"list": [b"foo", b"bar"],
"tuple": (b"foo", b"bar"),
"bytes": b"foo",
}
self.assertEqual(recursive_unicode(tests["dict"]), {u"foo": u"bar"})
self.assertEqual(recursive_unicode(tests["list"]), [u"foo", u"bar"])
self.assertEqual(recursive_unicode(tests["tuple"]), (u"foo", u"bar"))
self.assertEqual(recursive_unicode(tests["bytes"]), u"foo")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,47 @@
# SOME DESCRIPTIVE TITLE.
# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER
# This file is distributed under the same license as the PACKAGE package.
# FIRST AUTHOR <EMAIL@ADDRESS>, YEAR.
#
#, fuzzy
msgid ""
msgstr ""
"Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2015-01-27 11:05+0300\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <LL@li.org>\n"
"Language: \n"
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=utf-8\n"
"Content-Transfer-Encoding: 8bit\n"
"Plural-Forms: nplurals=2; plural=(n > 1);\n"
#: extract_me.py:11
msgid "school"
msgstr "école"
#: extract_me.py:12
msgctxt "law"
msgid "right"
msgstr "le droit"
#: extract_me.py:13
msgctxt "good"
msgid "right"
msgstr "le bien"
#: extract_me.py:14
msgctxt "organization"
msgid "club"
msgid_plural "clubs"
msgstr[0] "le club"
msgstr[1] "les clubs"
#: extract_me.py:15
msgctxt "stick"
msgid "club"
msgid_plural "clubs"
msgstr[0] "le bâton"
msgstr[1] "les bâtons"

View File

@@ -0,0 +1,58 @@
import socket
from tornado.http1connection import HTTP1Connection
from tornado.httputil import HTTPMessageDelegate
from tornado.iostream import IOStream
from tornado.locks import Event
from tornado.netutil import add_accept_handler
from tornado.testing import AsyncTestCase, bind_unused_port, gen_test
class HTTP1ConnectionTest(AsyncTestCase):
def setUp(self):
super(HTTP1ConnectionTest, self).setUp()
self.asyncSetUp()
@gen_test
def asyncSetUp(self):
listener, port = bind_unused_port()
event = Event()
def accept_callback(conn, addr):
self.server_stream = IOStream(conn)
self.addCleanup(self.server_stream.close)
event.set()
add_accept_handler(listener, accept_callback)
self.client_stream = IOStream(socket.socket())
self.addCleanup(self.client_stream.close)
yield [self.client_stream.connect(("127.0.0.1", port)), event.wait()]
self.io_loop.remove_handler(listener)
listener.close()
@gen_test
def test_http10_no_content_length(self):
# Regression test for a bug in which can_keep_alive would crash
# for an HTTP/1.0 (not 1.1) response with no content-length.
conn = HTTP1Connection(self.client_stream, True)
self.server_stream.write(b"HTTP/1.0 200 Not Modified\r\n\r\nhello")
self.server_stream.close()
event = Event()
test = self
body = []
class Delegate(HTTPMessageDelegate):
def headers_received(self, start_line, headers):
test.code = start_line.code
def data_received(self, data):
body.append(data)
def finish(self):
event.set()
yield conn.read_response(Delegate())
yield event.wait()
self.assertEqual(self.code, 200)
self.assertEqual(b"".join(body), b"hello")

View File

@@ -0,0 +1,832 @@
# -*- coding: utf-8 -*-
import base64
import binascii
from contextlib import closing
import copy
import threading
import datetime
from io import BytesIO
import subprocess
import sys
import time
import typing # noqa: F401
import unicodedata
import unittest
from tornado.escape import utf8, native_str, to_unicode
from tornado import gen
from tornado.httpclient import (
HTTPRequest,
HTTPResponse,
_RequestProxy,
HTTPError,
HTTPClient,
)
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream
from tornado.log import gen_log, app_log
from tornado import netutil
from tornado.testing import AsyncHTTPTestCase, bind_unused_port, gen_test, ExpectLog
from tornado.test.util import skipOnTravis
from tornado.web import Application, RequestHandler, url
from tornado.httputil import format_timestamp, HTTPHeaders
class HelloWorldHandler(RequestHandler):
def get(self):
name = self.get_argument("name", "world")
self.set_header("Content-Type", "text/plain")
self.finish("Hello %s!" % name)
class PostHandler(RequestHandler):
def post(self):
self.finish(
"Post arg1: %s, arg2: %s"
% (self.get_argument("arg1"), self.get_argument("arg2"))
)
class PutHandler(RequestHandler):
def put(self):
self.write("Put body: ")
self.write(self.request.body)
class RedirectHandler(RequestHandler):
def prepare(self):
self.write("redirects can have bodies too")
self.redirect(
self.get_argument("url"), status=int(self.get_argument("status", "302"))
)
class RedirectWithoutLocationHandler(RequestHandler):
def prepare(self):
# For testing error handling of a redirect with no location header.
self.set_status(301)
self.finish()
class ChunkHandler(RequestHandler):
@gen.coroutine
def get(self):
self.write("asdf")
self.flush()
# Wait a bit to ensure the chunks are sent and received separately.
yield gen.sleep(0.01)
self.write("qwer")
class AuthHandler(RequestHandler):
def get(self):
self.finish(self.request.headers["Authorization"])
class CountdownHandler(RequestHandler):
def get(self, count):
count = int(count)
if count > 0:
self.redirect(self.reverse_url("countdown", count - 1))
else:
self.write("Zero")
class EchoPostHandler(RequestHandler):
def post(self):
self.write(self.request.body)
class UserAgentHandler(RequestHandler):
def get(self):
self.write(self.request.headers.get("User-Agent", "User agent not set"))
class ContentLength304Handler(RequestHandler):
def get(self):
self.set_status(304)
self.set_header("Content-Length", 42)
def _clear_headers_for_304(self):
# Tornado strips content-length from 304 responses, but here we
# want to simulate servers that include the headers anyway.
pass
class PatchHandler(RequestHandler):
def patch(self):
"Return the request payload - so we can check it is being kept"
self.write(self.request.body)
class AllMethodsHandler(RequestHandler):
SUPPORTED_METHODS = RequestHandler.SUPPORTED_METHODS + ("OTHER",) # type: ignore
def method(self):
self.write(self.request.method)
get = head = post = put = delete = options = patch = other = method # type: ignore
class SetHeaderHandler(RequestHandler):
def get(self):
# Use get_arguments for keys to get strings, but
# request.arguments for values to get bytes.
for k, v in zip(self.get_arguments("k"), self.request.arguments["v"]):
self.set_header(k, v)
# These tests end up getting run redundantly: once here with the default
# HTTPClient implementation, and then again in each implementation's own
# test suite.
class HTTPClientCommonTestCase(AsyncHTTPTestCase):
def get_app(self):
return Application(
[
url("/hello", HelloWorldHandler),
url("/post", PostHandler),
url("/put", PutHandler),
url("/redirect", RedirectHandler),
url("/redirect_without_location", RedirectWithoutLocationHandler),
url("/chunk", ChunkHandler),
url("/auth", AuthHandler),
url("/countdown/([0-9]+)", CountdownHandler, name="countdown"),
url("/echopost", EchoPostHandler),
url("/user_agent", UserAgentHandler),
url("/304_with_content_length", ContentLength304Handler),
url("/all_methods", AllMethodsHandler),
url("/patch", PatchHandler),
url("/set_header", SetHeaderHandler),
],
gzip=True,
)
def test_patch_receives_payload(self):
body = b"some patch data"
response = self.fetch("/patch", method="PATCH", body=body)
self.assertEqual(response.code, 200)
self.assertEqual(response.body, body)
@skipOnTravis
def test_hello_world(self):
response = self.fetch("/hello")
self.assertEqual(response.code, 200)
self.assertEqual(response.headers["Content-Type"], "text/plain")
self.assertEqual(response.body, b"Hello world!")
self.assertEqual(int(response.request_time), 0)
response = self.fetch("/hello?name=Ben")
self.assertEqual(response.body, b"Hello Ben!")
def test_streaming_callback(self):
# streaming_callback is also tested in test_chunked
chunks = [] # type: typing.List[bytes]
response = self.fetch("/hello", streaming_callback=chunks.append)
# with streaming_callback, data goes to the callback and not response.body
self.assertEqual(chunks, [b"Hello world!"])
self.assertFalse(response.body)
def test_post(self):
response = self.fetch("/post", method="POST", body="arg1=foo&arg2=bar")
self.assertEqual(response.code, 200)
self.assertEqual(response.body, b"Post arg1: foo, arg2: bar")
def test_chunked(self):
response = self.fetch("/chunk")
self.assertEqual(response.body, b"asdfqwer")
chunks = [] # type: typing.List[bytes]
response = self.fetch("/chunk", streaming_callback=chunks.append)
self.assertEqual(chunks, [b"asdf", b"qwer"])
self.assertFalse(response.body)
def test_chunked_close(self):
# test case in which chunks spread read-callback processing
# over several ioloop iterations, but the connection is already closed.
sock, port = bind_unused_port()
with closing(sock):
@gen.coroutine
def accept_callback(conn, address):
# fake an HTTP server using chunked encoding where the final chunks
# and connection close all happen at once
stream = IOStream(conn)
request_data = yield stream.read_until(b"\r\n\r\n")
if b"HTTP/1." not in request_data:
self.skipTest("requires HTTP/1.x")
yield stream.write(
b"""\
HTTP/1.1 200 OK
Transfer-Encoding: chunked
1
1
1
2
0
""".replace(
b"\n", b"\r\n"
)
)
stream.close()
netutil.add_accept_handler(sock, accept_callback) # type: ignore
resp = self.fetch("http://127.0.0.1:%d/" % port)
resp.rethrow()
self.assertEqual(resp.body, b"12")
self.io_loop.remove_handler(sock.fileno())
def test_basic_auth(self):
# This test data appears in section 2 of RFC 7617.
self.assertEqual(
self.fetch(
"/auth", auth_username="Aladdin", auth_password="open sesame"
).body,
b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==",
)
def test_basic_auth_explicit_mode(self):
self.assertEqual(
self.fetch(
"/auth",
auth_username="Aladdin",
auth_password="open sesame",
auth_mode="basic",
).body,
b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==",
)
def test_basic_auth_unicode(self):
# This test data appears in section 2.1 of RFC 7617.
self.assertEqual(
self.fetch("/auth", auth_username="test", auth_password="123£").body,
b"Basic dGVzdDoxMjPCow==",
)
# The standard mandates NFC. Give it a decomposed username
# and ensure it is normalized to composed form.
username = unicodedata.normalize("NFD", u"josé")
self.assertEqual(
self.fetch("/auth", auth_username=username, auth_password="səcrət").body,
b"Basic am9zw6k6c8mZY3LJmXQ=",
)
def test_unsupported_auth_mode(self):
# curl and simple clients handle errors a bit differently; the
# important thing is that they don't fall back to basic auth
# on an unknown mode.
with ExpectLog(gen_log, "uncaught exception", required=False):
with self.assertRaises((ValueError, HTTPError)):
self.fetch(
"/auth",
auth_username="Aladdin",
auth_password="open sesame",
auth_mode="asdf",
raise_error=True,
)
def test_follow_redirect(self):
response = self.fetch("/countdown/2", follow_redirects=False)
self.assertEqual(302, response.code)
self.assertTrue(response.headers["Location"].endswith("/countdown/1"))
response = self.fetch("/countdown/2")
self.assertEqual(200, response.code)
self.assertTrue(response.effective_url.endswith("/countdown/0"))
self.assertEqual(b"Zero", response.body)
def test_redirect_without_location(self):
response = self.fetch("/redirect_without_location", follow_redirects=True)
# If there is no location header, the redirect response should
# just be returned as-is. (This should arguably raise an
# error, but libcurl doesn't treat this as an error, so we
# don't either).
self.assertEqual(301, response.code)
def test_redirect_put_with_body(self):
response = self.fetch(
"/redirect?url=/put&status=307", method="PUT", body="hello"
)
self.assertEqual(response.body, b"Put body: hello")
def test_redirect_put_without_body(self):
# This "without body" edge case is similar to what happens with body_producer.
response = self.fetch(
"/redirect?url=/put&status=307",
method="PUT",
allow_nonstandard_methods=True,
)
self.assertEqual(response.body, b"Put body: ")
def test_method_after_redirect(self):
# Legacy redirect codes (301, 302) convert POST requests to GET.
for status in [301, 302, 303]:
url = "/redirect?url=/all_methods&status=%d" % status
resp = self.fetch(url, method="POST", body=b"")
self.assertEqual(b"GET", resp.body)
# Other methods are left alone.
for method in ["GET", "OPTIONS", "PUT", "DELETE"]:
resp = self.fetch(url, method=method, allow_nonstandard_methods=True)
self.assertEqual(utf8(method), resp.body)
# HEAD is different so check it separately.
resp = self.fetch(url, method="HEAD")
self.assertEqual(200, resp.code)
self.assertEqual(b"", resp.body)
# Newer redirects always preserve the original method.
for status in [307, 308]:
url = "/redirect?url=/all_methods&status=307"
for method in ["GET", "OPTIONS", "POST", "PUT", "DELETE"]:
resp = self.fetch(url, method=method, allow_nonstandard_methods=True)
self.assertEqual(method, to_unicode(resp.body))
resp = self.fetch(url, method="HEAD")
self.assertEqual(200, resp.code)
self.assertEqual(b"", resp.body)
def test_credentials_in_url(self):
url = self.get_url("/auth").replace("http://", "http://me:secret@")
response = self.fetch(url)
self.assertEqual(b"Basic " + base64.b64encode(b"me:secret"), response.body)
def test_body_encoding(self):
unicode_body = u"\xe9"
byte_body = binascii.a2b_hex(b"e9")
# unicode string in body gets converted to utf8
response = self.fetch(
"/echopost",
method="POST",
body=unicode_body,
headers={"Content-Type": "application/blah"},
)
self.assertEqual(response.headers["Content-Length"], "2")
self.assertEqual(response.body, utf8(unicode_body))
# byte strings pass through directly
response = self.fetch(
"/echopost",
method="POST",
body=byte_body,
headers={"Content-Type": "application/blah"},
)
self.assertEqual(response.headers["Content-Length"], "1")
self.assertEqual(response.body, byte_body)
# Mixing unicode in headers and byte string bodies shouldn't
# break anything
response = self.fetch(
"/echopost",
method="POST",
body=byte_body,
headers={"Content-Type": "application/blah"},
user_agent=u"foo",
)
self.assertEqual(response.headers["Content-Length"], "1")
self.assertEqual(response.body, byte_body)
def test_types(self):
response = self.fetch("/hello")
self.assertEqual(type(response.body), bytes)
self.assertEqual(type(response.headers["Content-Type"]), str)
self.assertEqual(type(response.code), int)
self.assertEqual(type(response.effective_url), str)
def test_header_callback(self):
first_line = []
headers = {}
chunks = []
def header_callback(header_line):
if header_line.startswith("HTTP/1.1 101"):
# Upgrading to HTTP/2
pass
elif header_line.startswith("HTTP/"):
first_line.append(header_line)
elif header_line != "\r\n":
k, v = header_line.split(":", 1)
headers[k.lower()] = v.strip()
def streaming_callback(chunk):
# All header callbacks are run before any streaming callbacks,
# so the header data is available to process the data as it
# comes in.
self.assertEqual(headers["content-type"], "text/html; charset=UTF-8")
chunks.append(chunk)
self.fetch(
"/chunk",
header_callback=header_callback,
streaming_callback=streaming_callback,
)
self.assertEqual(len(first_line), 1, first_line)
self.assertRegexpMatches(first_line[0], "HTTP/[0-9]\\.[0-9] 200.*\r\n")
self.assertEqual(chunks, [b"asdf", b"qwer"])
@gen_test
def test_configure_defaults(self):
defaults = dict(user_agent="TestDefaultUserAgent", allow_ipv6=False)
# Construct a new instance of the configured client class
client = self.http_client.__class__(force_instance=True, defaults=defaults)
try:
response = yield client.fetch(self.get_url("/user_agent"))
self.assertEqual(response.body, b"TestDefaultUserAgent")
finally:
client.close()
def test_header_types(self):
# Header values may be passed as character or utf8 byte strings,
# in a plain dictionary or an HTTPHeaders object.
# Keys must always be the native str type.
# All combinations should have the same results on the wire.
for value in [u"MyUserAgent", b"MyUserAgent"]:
for container in [dict, HTTPHeaders]:
headers = container()
headers["User-Agent"] = value
resp = self.fetch("/user_agent", headers=headers)
self.assertEqual(
resp.body,
b"MyUserAgent",
"response=%r, value=%r, container=%r"
% (resp.body, value, container),
)
def test_multi_line_headers(self):
# Multi-line http headers are rare but rfc-allowed
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2
sock, port = bind_unused_port()
with closing(sock):
@gen.coroutine
def accept_callback(conn, address):
stream = IOStream(conn)
request_data = yield stream.read_until(b"\r\n\r\n")
if b"HTTP/1." not in request_data:
self.skipTest("requires HTTP/1.x")
yield stream.write(
b"""\
HTTP/1.1 200 OK
X-XSS-Protection: 1;
\tmode=block
""".replace(
b"\n", b"\r\n"
)
)
stream.close()
netutil.add_accept_handler(sock, accept_callback) # type: ignore
resp = self.fetch("http://127.0.0.1:%d/" % port)
resp.rethrow()
self.assertEqual(resp.headers["X-XSS-Protection"], "1; mode=block")
self.io_loop.remove_handler(sock.fileno())
def test_304_with_content_length(self):
# According to the spec 304 responses SHOULD NOT include
# Content-Length or other entity headers, but some servers do it
# anyway.
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.5
response = self.fetch("/304_with_content_length")
self.assertEqual(response.code, 304)
self.assertEqual(response.headers["Content-Length"], "42")
@gen_test
def test_future_interface(self):
response = yield self.http_client.fetch(self.get_url("/hello"))
self.assertEqual(response.body, b"Hello world!")
@gen_test
def test_future_http_error(self):
with self.assertRaises(HTTPError) as context:
yield self.http_client.fetch(self.get_url("/notfound"))
self.assertEqual(context.exception.code, 404)
self.assertEqual(context.exception.response.code, 404)
@gen_test
def test_future_http_error_no_raise(self):
response = yield self.http_client.fetch(
self.get_url("/notfound"), raise_error=False
)
self.assertEqual(response.code, 404)
@gen_test
def test_reuse_request_from_response(self):
# The response.request attribute should be an HTTPRequest, not
# a _RequestProxy.
# This test uses self.http_client.fetch because self.fetch calls
# self.get_url on the input unconditionally.
url = self.get_url("/hello")
response = yield self.http_client.fetch(url)
self.assertEqual(response.request.url, url)
self.assertTrue(isinstance(response.request, HTTPRequest))
response2 = yield self.http_client.fetch(response.request)
self.assertEqual(response2.body, b"Hello world!")
@gen_test
def test_bind_source_ip(self):
url = self.get_url("/hello")
request = HTTPRequest(url, network_interface="127.0.0.1")
response = yield self.http_client.fetch(request)
self.assertEqual(response.code, 200)
with self.assertRaises((ValueError, HTTPError)) as context:
request = HTTPRequest(url, network_interface="not-interface-or-ip")
yield self.http_client.fetch(request)
self.assertIn("not-interface-or-ip", str(context.exception))
def test_all_methods(self):
for method in ["GET", "DELETE", "OPTIONS"]:
response = self.fetch("/all_methods", method=method)
self.assertEqual(response.body, utf8(method))
for method in ["POST", "PUT", "PATCH"]:
response = self.fetch("/all_methods", method=method, body=b"")
self.assertEqual(response.body, utf8(method))
response = self.fetch("/all_methods", method="HEAD")
self.assertEqual(response.body, b"")
response = self.fetch(
"/all_methods", method="OTHER", allow_nonstandard_methods=True
)
self.assertEqual(response.body, b"OTHER")
def test_body_sanity_checks(self):
# These methods require a body.
for method in ("POST", "PUT", "PATCH"):
with self.assertRaises(ValueError) as context:
self.fetch("/all_methods", method=method, raise_error=True)
self.assertIn("must not be None", str(context.exception))
resp = self.fetch(
"/all_methods", method=method, allow_nonstandard_methods=True
)
self.assertEqual(resp.code, 200)
# These methods don't allow a body.
for method in ("GET", "DELETE", "OPTIONS"):
with self.assertRaises(ValueError) as context:
self.fetch(
"/all_methods", method=method, body=b"asdf", raise_error=True
)
self.assertIn("must be None", str(context.exception))
# In most cases this can be overridden, but curl_httpclient
# does not allow body with a GET at all.
if method != "GET":
self.fetch(
"/all_methods",
method=method,
body=b"asdf",
allow_nonstandard_methods=True,
raise_error=True,
)
self.assertEqual(resp.code, 200)
# This test causes odd failures with the combination of
# curl_httpclient (at least with the version of libcurl available
# on ubuntu 12.04), TwistedIOLoop, and epoll. For POST (but not PUT),
# curl decides the response came back too soon and closes the connection
# to start again. It does this *before* telling the socket callback to
# unregister the FD. Some IOLoop implementations have special kernel
# integration to discover this immediately. Tornado's IOLoops
# ignore errors on remove_handler to accommodate this behavior, but
# Twisted's reactor does not. The removeReader call fails and so
# do all future removeAll calls (which our tests do at cleanup).
#
# def test_post_307(self):
# response = self.fetch("/redirect?status=307&url=/post",
# method="POST", body=b"arg1=foo&arg2=bar")
# self.assertEqual(response.body, b"Post arg1: foo, arg2: bar")
def test_put_307(self):
response = self.fetch(
"/redirect?status=307&url=/put", method="PUT", body=b"hello"
)
response.rethrow()
self.assertEqual(response.body, b"Put body: hello")
def test_non_ascii_header(self):
# Non-ascii headers are sent as latin1.
response = self.fetch("/set_header?k=foo&v=%E9")
response.rethrow()
self.assertEqual(response.headers["Foo"], native_str(u"\u00e9"))
def test_response_times(self):
# A few simple sanity checks of the response time fields to
# make sure they're using the right basis (between the
# wall-time and monotonic clocks).
start_time = time.time()
response = self.fetch("/hello")
response.rethrow()
self.assertGreaterEqual(response.request_time, 0)
self.assertLess(response.request_time, 1.0)
# A very crude check to make sure that start_time is based on
# wall time and not the monotonic clock.
self.assertLess(abs(response.start_time - start_time), 1.0)
for k, v in response.time_info.items():
self.assertTrue(0 <= v < 1.0, "time_info[%s] out of bounds: %s" % (k, v))
@gen_test
def test_error_after_cancel(self):
fut = self.http_client.fetch(self.get_url("/404"))
self.assertTrue(fut.cancel())
with ExpectLog(app_log, "Exception after Future was cancelled") as el:
# We can't wait on the cancelled Future any more, so just
# let the IOLoop run until the exception gets logged (or
# not, in which case we exit the loop and ExpectLog will
# raise).
for i in range(100):
yield gen.sleep(0.01)
if el.logged_stack:
break
class RequestProxyTest(unittest.TestCase):
def test_request_set(self):
proxy = _RequestProxy(
HTTPRequest("http://example.com/", user_agent="foo"), dict()
)
self.assertEqual(proxy.user_agent, "foo")
def test_default_set(self):
proxy = _RequestProxy(
HTTPRequest("http://example.com/"), dict(network_interface="foo")
)
self.assertEqual(proxy.network_interface, "foo")
def test_both_set(self):
proxy = _RequestProxy(
HTTPRequest("http://example.com/", proxy_host="foo"), dict(proxy_host="bar")
)
self.assertEqual(proxy.proxy_host, "foo")
def test_neither_set(self):
proxy = _RequestProxy(HTTPRequest("http://example.com/"), dict())
self.assertIs(proxy.auth_username, None)
def test_bad_attribute(self):
proxy = _RequestProxy(HTTPRequest("http://example.com/"), dict())
with self.assertRaises(AttributeError):
proxy.foo
def test_defaults_none(self):
proxy = _RequestProxy(HTTPRequest("http://example.com/"), None)
self.assertIs(proxy.auth_username, None)
class HTTPResponseTestCase(unittest.TestCase):
def test_str(self):
response = HTTPResponse( # type: ignore
HTTPRequest("http://example.com"), 200, headers={}, buffer=BytesIO()
)
s = str(response)
self.assertTrue(s.startswith("HTTPResponse("))
self.assertIn("code=200", s)
class SyncHTTPClientTest(unittest.TestCase):
def setUp(self):
self.server_ioloop = IOLoop()
event = threading.Event()
@gen.coroutine
def init_server():
sock, self.port = bind_unused_port()
app = Application([("/", HelloWorldHandler)])
self.server = HTTPServer(app)
self.server.add_socket(sock)
event.set()
def start():
self.server_ioloop.run_sync(init_server)
self.server_ioloop.start()
self.server_thread = threading.Thread(target=start)
self.server_thread.start()
event.wait()
self.http_client = HTTPClient()
def tearDown(self):
def stop_server():
self.server.stop()
# Delay the shutdown of the IOLoop by several iterations because
# the server may still have some cleanup work left when
# the client finishes with the response (this is noticeable
# with http/2, which leaves a Future with an unexamined
# StreamClosedError on the loop).
@gen.coroutine
def slow_stop():
yield self.server.close_all_connections()
# The number of iterations is difficult to predict. Typically,
# one is sufficient, although sometimes it needs more.
for i in range(5):
yield
self.server_ioloop.stop()
self.server_ioloop.add_callback(slow_stop)
self.server_ioloop.add_callback(stop_server)
self.server_thread.join()
self.http_client.close()
self.server_ioloop.close(all_fds=True)
def get_url(self, path):
return "http://127.0.0.1:%d%s" % (self.port, path)
def test_sync_client(self):
response = self.http_client.fetch(self.get_url("/"))
self.assertEqual(b"Hello world!", response.body)
def test_sync_client_error(self):
# Synchronous HTTPClient raises errors directly; no need for
# response.rethrow()
with self.assertRaises(HTTPError) as assertion:
self.http_client.fetch(self.get_url("/notfound"))
self.assertEqual(assertion.exception.code, 404)
class SyncHTTPClientSubprocessTest(unittest.TestCase):
def test_destructor_log(self):
# Regression test for
# https://github.com/tornadoweb/tornado/issues/2539
#
# In the past, the following program would log an
# "inconsistent AsyncHTTPClient cache" error from a destructor
# when the process is shutting down. The shutdown process is
# subtle and I don't fully understand it; the failure does not
# manifest if that lambda isn't there or is a simpler object
# like an int (nor does it manifest in the tornado test suite
# as a whole, which is why we use this subprocess).
proc = subprocess.run(
[
sys.executable,
"-c",
"from tornado.httpclient import HTTPClient; f = lambda: None; c = HTTPClient()",
],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
check=True,
)
if proc.stdout:
print("STDOUT:")
print(to_unicode(proc.stdout))
if proc.stdout:
self.fail("subprocess produced unexpected output")
class HTTPRequestTestCase(unittest.TestCase):
def test_headers(self):
request = HTTPRequest("http://example.com", headers={"foo": "bar"})
self.assertEqual(request.headers, {"foo": "bar"})
def test_headers_setter(self):
request = HTTPRequest("http://example.com")
request.headers = {"bar": "baz"} # type: ignore
self.assertEqual(request.headers, {"bar": "baz"})
def test_null_headers_setter(self):
request = HTTPRequest("http://example.com")
request.headers = None # type: ignore
self.assertEqual(request.headers, {})
def test_body(self):
request = HTTPRequest("http://example.com", body="foo")
self.assertEqual(request.body, utf8("foo"))
def test_body_setter(self):
request = HTTPRequest("http://example.com")
request.body = "foo" # type: ignore
self.assertEqual(request.body, utf8("foo"))
def test_if_modified_since(self):
http_date = datetime.datetime.utcnow()
request = HTTPRequest("http://example.com", if_modified_since=http_date)
self.assertEqual(
request.headers, {"If-Modified-Since": format_timestamp(http_date)}
)
class HTTPErrorTestCase(unittest.TestCase):
def test_copy(self):
e = HTTPError(403)
e2 = copy.copy(e)
self.assertIsNot(e, e2)
self.assertEqual(e.code, e2.code)
def test_plain_error(self):
e = HTTPError(403)
self.assertEqual(str(e), "HTTP 403: Forbidden")
self.assertEqual(repr(e), "HTTP 403: Forbidden")
def test_error_with_response(self):
resp = HTTPResponse(HTTPRequest("http://example.com/"), 403)
with self.assertRaises(HTTPError) as cm:
resp.rethrow()
e = cm.exception
self.assertEqual(str(e), "HTTP 403: Forbidden")
self.assertEqual(repr(e), "HTTP 403: Forbidden")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,522 @@
# -*- coding: utf-8 -*-
from tornado.httputil import (
url_concat,
parse_multipart_form_data,
HTTPHeaders,
format_timestamp,
HTTPServerRequest,
parse_request_start_line,
parse_cookie,
qs_to_qsl,
HTTPInputError,
HTTPFile,
)
from tornado.escape import utf8, native_str
from tornado.log import gen_log
from tornado.testing import ExpectLog
import copy
import datetime
import logging
import pickle
import time
import urllib.parse
import unittest
from typing import Tuple, Dict, List
def form_data_args() -> Tuple[Dict[str, List[bytes]], Dict[str, List[HTTPFile]]]:
"""Return two empty dicts suitable for use with parse_multipart_form_data.
mypy insists on type annotations for dict literals, so this lets us avoid
the verbose types throughout this test.
"""
return {}, {}
class TestUrlConcat(unittest.TestCase):
def test_url_concat_no_query_params(self):
url = url_concat("https://localhost/path", [("y", "y"), ("z", "z")])
self.assertEqual(url, "https://localhost/path?y=y&z=z")
def test_url_concat_encode_args(self):
url = url_concat("https://localhost/path", [("y", "/y"), ("z", "z")])
self.assertEqual(url, "https://localhost/path?y=%2Fy&z=z")
def test_url_concat_trailing_q(self):
url = url_concat("https://localhost/path?", [("y", "y"), ("z", "z")])
self.assertEqual(url, "https://localhost/path?y=y&z=z")
def test_url_concat_q_with_no_trailing_amp(self):
url = url_concat("https://localhost/path?x", [("y", "y"), ("z", "z")])
self.assertEqual(url, "https://localhost/path?x=&y=y&z=z")
def test_url_concat_trailing_amp(self):
url = url_concat("https://localhost/path?x&", [("y", "y"), ("z", "z")])
self.assertEqual(url, "https://localhost/path?x=&y=y&z=z")
def test_url_concat_mult_params(self):
url = url_concat("https://localhost/path?a=1&b=2", [("y", "y"), ("z", "z")])
self.assertEqual(url, "https://localhost/path?a=1&b=2&y=y&z=z")
def test_url_concat_no_params(self):
url = url_concat("https://localhost/path?r=1&t=2", [])
self.assertEqual(url, "https://localhost/path?r=1&t=2")
def test_url_concat_none_params(self):
url = url_concat("https://localhost/path?r=1&t=2", None)
self.assertEqual(url, "https://localhost/path?r=1&t=2")
def test_url_concat_with_frag(self):
url = url_concat("https://localhost/path#tab", [("y", "y")])
self.assertEqual(url, "https://localhost/path?y=y#tab")
def test_url_concat_multi_same_params(self):
url = url_concat("https://localhost/path", [("y", "y1"), ("y", "y2")])
self.assertEqual(url, "https://localhost/path?y=y1&y=y2")
def test_url_concat_multi_same_query_params(self):
url = url_concat("https://localhost/path?r=1&r=2", [("y", "y")])
self.assertEqual(url, "https://localhost/path?r=1&r=2&y=y")
def test_url_concat_dict_params(self):
url = url_concat("https://localhost/path", dict(y="y"))
self.assertEqual(url, "https://localhost/path?y=y")
class QsParseTest(unittest.TestCase):
def test_parsing(self):
qsstring = "a=1&b=2&a=3"
qs = urllib.parse.parse_qs(qsstring)
qsl = list(qs_to_qsl(qs))
self.assertIn(("a", "1"), qsl)
self.assertIn(("a", "3"), qsl)
self.assertIn(("b", "2"), qsl)
class MultipartFormDataTest(unittest.TestCase):
def test_file_upload(self):
data = b"""\
--1234
Content-Disposition: form-data; name="files"; filename="ab.txt"
Foo
--1234--""".replace(
b"\n", b"\r\n"
)
args, files = form_data_args()
parse_multipart_form_data(b"1234", data, args, files)
file = files["files"][0]
self.assertEqual(file["filename"], "ab.txt")
self.assertEqual(file["body"], b"Foo")
def test_unquoted_names(self):
# quotes are optional unless special characters are present
data = b"""\
--1234
Content-Disposition: form-data; name=files; filename=ab.txt
Foo
--1234--""".replace(
b"\n", b"\r\n"
)
args, files = form_data_args()
parse_multipart_form_data(b"1234", data, args, files)
file = files["files"][0]
self.assertEqual(file["filename"], "ab.txt")
self.assertEqual(file["body"], b"Foo")
def test_special_filenames(self):
filenames = [
"a;b.txt",
'a"b.txt',
'a";b.txt',
'a;"b.txt',
'a";";.txt',
'a\\"b.txt',
"a\\b.txt",
]
for filename in filenames:
logging.debug("trying filename %r", filename)
str_data = """\
--1234
Content-Disposition: form-data; name="files"; filename="%s"
Foo
--1234--""" % filename.replace(
"\\", "\\\\"
).replace(
'"', '\\"'
)
data = utf8(str_data.replace("\n", "\r\n"))
args, files = form_data_args()
parse_multipart_form_data(b"1234", data, args, files)
file = files["files"][0]
self.assertEqual(file["filename"], filename)
self.assertEqual(file["body"], b"Foo")
def test_non_ascii_filename(self):
data = b"""\
--1234
Content-Disposition: form-data; name="files"; filename="ab.txt"; filename*=UTF-8''%C3%A1b.txt
Foo
--1234--""".replace(
b"\n", b"\r\n"
)
args, files = form_data_args()
parse_multipart_form_data(b"1234", data, args, files)
file = files["files"][0]
self.assertEqual(file["filename"], u"áb.txt")
self.assertEqual(file["body"], b"Foo")
def test_boundary_starts_and_ends_with_quotes(self):
data = b"""\
--1234
Content-Disposition: form-data; name="files"; filename="ab.txt"
Foo
--1234--""".replace(
b"\n", b"\r\n"
)
args, files = form_data_args()
parse_multipart_form_data(b'"1234"', data, args, files)
file = files["files"][0]
self.assertEqual(file["filename"], "ab.txt")
self.assertEqual(file["body"], b"Foo")
def test_missing_headers(self):
data = b"""\
--1234
Foo
--1234--""".replace(
b"\n", b"\r\n"
)
args, files = form_data_args()
with ExpectLog(gen_log, "multipart/form-data missing headers"):
parse_multipart_form_data(b"1234", data, args, files)
self.assertEqual(files, {})
def test_invalid_content_disposition(self):
data = b"""\
--1234
Content-Disposition: invalid; name="files"; filename="ab.txt"
Foo
--1234--""".replace(
b"\n", b"\r\n"
)
args, files = form_data_args()
with ExpectLog(gen_log, "Invalid multipart/form-data"):
parse_multipart_form_data(b"1234", data, args, files)
self.assertEqual(files, {})
def test_line_does_not_end_with_correct_line_break(self):
data = b"""\
--1234
Content-Disposition: form-data; name="files"; filename="ab.txt"
Foo--1234--""".replace(
b"\n", b"\r\n"
)
args, files = form_data_args()
with ExpectLog(gen_log, "Invalid multipart/form-data"):
parse_multipart_form_data(b"1234", data, args, files)
self.assertEqual(files, {})
def test_content_disposition_header_without_name_parameter(self):
data = b"""\
--1234
Content-Disposition: form-data; filename="ab.txt"
Foo
--1234--""".replace(
b"\n", b"\r\n"
)
args, files = form_data_args()
with ExpectLog(gen_log, "multipart/form-data value missing name"):
parse_multipart_form_data(b"1234", data, args, files)
self.assertEqual(files, {})
def test_data_after_final_boundary(self):
# The spec requires that data after the final boundary be ignored.
# http://www.w3.org/Protocols/rfc1341/7_2_Multipart.html
# In practice, some libraries include an extra CRLF after the boundary.
data = b"""\
--1234
Content-Disposition: form-data; name="files"; filename="ab.txt"
Foo
--1234--
""".replace(
b"\n", b"\r\n"
)
args, files = form_data_args()
parse_multipart_form_data(b"1234", data, args, files)
file = files["files"][0]
self.assertEqual(file["filename"], "ab.txt")
self.assertEqual(file["body"], b"Foo")
class HTTPHeadersTest(unittest.TestCase):
def test_multi_line(self):
# Lines beginning with whitespace are appended to the previous line
# with any leading whitespace replaced by a single space.
# Note that while multi-line headers are a part of the HTTP spec,
# their use is strongly discouraged.
data = """\
Foo: bar
baz
Asdf: qwer
\tzxcv
Foo: even
more
lines
""".replace(
"\n", "\r\n"
)
headers = HTTPHeaders.parse(data)
self.assertEqual(headers["asdf"], "qwer zxcv")
self.assertEqual(headers.get_list("asdf"), ["qwer zxcv"])
self.assertEqual(headers["Foo"], "bar baz,even more lines")
self.assertEqual(headers.get_list("foo"), ["bar baz", "even more lines"])
self.assertEqual(
sorted(list(headers.get_all())),
[("Asdf", "qwer zxcv"), ("Foo", "bar baz"), ("Foo", "even more lines")],
)
def test_malformed_continuation(self):
# If the first line starts with whitespace, it's a
# continuation line with nothing to continue, so reject it
# (with a proper error).
data = " Foo: bar"
self.assertRaises(HTTPInputError, HTTPHeaders.parse, data)
def test_unicode_newlines(self):
# Ensure that only \r\n is recognized as a header separator, and not
# the other newline-like unicode characters.
# Characters that are likely to be problematic can be found in
# http://unicode.org/standard/reports/tr13/tr13-5.html
# and cpython's unicodeobject.c (which defines the implementation
# of unicode_type.splitlines(), and uses a different list than TR13).
newlines = [
u"\u001b", # VERTICAL TAB
u"\u001c", # FILE SEPARATOR
u"\u001d", # GROUP SEPARATOR
u"\u001e", # RECORD SEPARATOR
u"\u0085", # NEXT LINE
u"\u2028", # LINE SEPARATOR
u"\u2029", # PARAGRAPH SEPARATOR
]
for newline in newlines:
# Try the utf8 and latin1 representations of each newline
for encoding in ["utf8", "latin1"]:
try:
try:
encoded = newline.encode(encoding)
except UnicodeEncodeError:
# Some chars cannot be represented in latin1
continue
data = b"Cookie: foo=" + encoded + b"bar"
# parse() wants a native_str, so decode through latin1
# in the same way the real parser does.
headers = HTTPHeaders.parse(native_str(data.decode("latin1")))
expected = [
(
"Cookie",
"foo=" + native_str(encoded.decode("latin1")) + "bar",
)
]
self.assertEqual(expected, list(headers.get_all()))
except Exception:
gen_log.warning("failed while trying %r in %s", newline, encoding)
raise
def test_optional_cr(self):
# Both CRLF and LF should be accepted as separators. CR should not be
# part of the data when followed by LF, but it is a normal char
# otherwise (or should bare CR be an error?)
headers = HTTPHeaders.parse("CRLF: crlf\r\nLF: lf\nCR: cr\rMore: more\r\n")
self.assertEqual(
sorted(headers.get_all()),
[("Cr", "cr\rMore: more"), ("Crlf", "crlf"), ("Lf", "lf")],
)
def test_copy(self):
all_pairs = [("A", "1"), ("A", "2"), ("B", "c")]
h1 = HTTPHeaders()
for k, v in all_pairs:
h1.add(k, v)
h2 = h1.copy()
h3 = copy.copy(h1)
h4 = copy.deepcopy(h1)
for headers in [h1, h2, h3, h4]:
# All the copies are identical, no matter how they were
# constructed.
self.assertEqual(list(sorted(headers.get_all())), all_pairs)
for headers in [h2, h3, h4]:
# Neither the dict or its member lists are reused.
self.assertIsNot(headers, h1)
self.assertIsNot(headers.get_list("A"), h1.get_list("A"))
def test_pickle_roundtrip(self):
headers = HTTPHeaders()
headers.add("Set-Cookie", "a=b")
headers.add("Set-Cookie", "c=d")
headers.add("Content-Type", "text/html")
pickled = pickle.dumps(headers)
unpickled = pickle.loads(pickled)
self.assertEqual(sorted(headers.get_all()), sorted(unpickled.get_all()))
self.assertEqual(sorted(headers.items()), sorted(unpickled.items()))
def test_setdefault(self):
headers = HTTPHeaders()
headers["foo"] = "bar"
# If a value is present, setdefault returns it without changes.
self.assertEqual(headers.setdefault("foo", "baz"), "bar")
self.assertEqual(headers["foo"], "bar")
# If a value is not present, setdefault sets it for future use.
self.assertEqual(headers.setdefault("quux", "xyzzy"), "xyzzy")
self.assertEqual(headers["quux"], "xyzzy")
self.assertEqual(sorted(headers.get_all()), [("Foo", "bar"), ("Quux", "xyzzy")])
def test_string(self):
headers = HTTPHeaders()
headers.add("Foo", "1")
headers.add("Foo", "2")
headers.add("Foo", "3")
headers2 = HTTPHeaders.parse(str(headers))
self.assertEquals(headers, headers2)
class FormatTimestampTest(unittest.TestCase):
# Make sure that all the input types are supported.
TIMESTAMP = 1359312200.503611
EXPECTED = "Sun, 27 Jan 2013 18:43:20 GMT"
def check(self, value):
self.assertEqual(format_timestamp(value), self.EXPECTED)
def test_unix_time_float(self):
self.check(self.TIMESTAMP)
def test_unix_time_int(self):
self.check(int(self.TIMESTAMP))
def test_struct_time(self):
self.check(time.gmtime(self.TIMESTAMP))
def test_time_tuple(self):
tup = tuple(time.gmtime(self.TIMESTAMP))
self.assertEqual(9, len(tup))
self.check(tup)
def test_datetime(self):
self.check(datetime.datetime.utcfromtimestamp(self.TIMESTAMP))
# HTTPServerRequest is mainly tested incidentally to the server itself,
# but this tests the parts of the class that can be tested in isolation.
class HTTPServerRequestTest(unittest.TestCase):
def test_default_constructor(self):
# All parameters are formally optional, but uri is required
# (and has been for some time). This test ensures that no
# more required parameters slip in.
HTTPServerRequest(uri="/")
def test_body_is_a_byte_string(self):
requets = HTTPServerRequest(uri="/")
self.assertIsInstance(requets.body, bytes)
def test_repr_does_not_contain_headers(self):
request = HTTPServerRequest(
uri="/", headers=HTTPHeaders({"Canary": ["Coal Mine"]})
)
self.assertTrue("Canary" not in repr(request))
class ParseRequestStartLineTest(unittest.TestCase):
METHOD = "GET"
PATH = "/foo"
VERSION = "HTTP/1.1"
def test_parse_request_start_line(self):
start_line = " ".join([self.METHOD, self.PATH, self.VERSION])
parsed_start_line = parse_request_start_line(start_line)
self.assertEqual(parsed_start_line.method, self.METHOD)
self.assertEqual(parsed_start_line.path, self.PATH)
self.assertEqual(parsed_start_line.version, self.VERSION)
class ParseCookieTest(unittest.TestCase):
# These tests copied from Django:
# https://github.com/django/django/pull/6277/commits/da810901ada1cae9fc1f018f879f11a7fb467b28
def test_python_cookies(self):
"""
Test cases copied from Python's Lib/test/test_http_cookies.py
"""
self.assertEqual(
parse_cookie("chips=ahoy; vienna=finger"),
{"chips": "ahoy", "vienna": "finger"},
)
# Here parse_cookie() differs from Python's cookie parsing in that it
# treats all semicolons as delimiters, even within quotes.
self.assertEqual(
parse_cookie('keebler="E=mc2; L=\\"Loves\\"; fudge=\\012;"'),
{"keebler": '"E=mc2', "L": '\\"Loves\\"', "fudge": "\\012", "": '"'},
)
# Illegal cookies that have an '=' char in an unquoted value.
self.assertEqual(parse_cookie("keebler=E=mc2"), {"keebler": "E=mc2"})
# Cookies with ':' character in their name.
self.assertEqual(
parse_cookie("key:term=value:term"), {"key:term": "value:term"}
)
# Cookies with '[' and ']'.
self.assertEqual(
parse_cookie("a=b; c=[; d=r; f=h"), {"a": "b", "c": "[", "d": "r", "f": "h"}
)
def test_cookie_edgecases(self):
# Cookies that RFC6265 allows.
self.assertEqual(
parse_cookie("a=b; Domain=example.com"), {"a": "b", "Domain": "example.com"}
)
# parse_cookie() has historically kept only the last cookie with the
# same name.
self.assertEqual(parse_cookie("a=b; h=i; a=c"), {"a": "c", "h": "i"})
def test_invalid_cookies(self):
"""
Cookie strings that go against RFC6265 but browsers will send if set
via document.cookie.
"""
# Chunks without an equals sign appear as unnamed values per
# https://bugzilla.mozilla.org/show_bug.cgi?id=169091
self.assertIn(
"django_language",
parse_cookie("abc=def; unnamed; django_language=en").keys(),
)
# Even a double quote may be an unamed value.
self.assertEqual(parse_cookie('a=b; "; c=d'), {"a": "b", "": '"', "c": "d"})
# Spaces in names and values, and an equals sign in values.
self.assertEqual(
parse_cookie("a b c=d e = f; gh=i"), {"a b c": "d e = f", "gh": "i"}
)
# More characters the spec forbids.
self.assertEqual(
parse_cookie('a b,c<>@:/[]?{}=d " =e,f g'),
{"a b,c<>@:/[]?{}": 'd " =e,f g'},
)
# Unicode characters. The spec only allows ASCII.
self.assertEqual(
parse_cookie("saint=André Bessette"),
{"saint": native_str("André Bessette")},
)
# Browsers don't send extra whitespace or semicolons in Cookie headers,
# but parse_cookie() should parse whitespace the same way
# document.cookie parses whitespace.
self.assertEqual(
parse_cookie(" = b ; ; = ; c = ; "), {"": "b", "c": ""}
)

View File

@@ -0,0 +1,66 @@
# flake8: noqa
import subprocess
import sys
import unittest
_import_everything = b"""
# The event loop is not fork-safe, and it's easy to initialize an asyncio.Future
# at startup, which in turn creates the default event loop and prevents forking.
# Explicitly disallow the default event loop so that an error will be raised
# if something tries to touch it.
import asyncio
asyncio.set_event_loop(None)
import tornado.auth
import tornado.autoreload
import tornado.concurrent
import tornado.escape
import tornado.gen
import tornado.http1connection
import tornado.httpclient
import tornado.httpserver
import tornado.httputil
import tornado.ioloop
import tornado.iostream
import tornado.locale
import tornado.log
import tornado.netutil
import tornado.options
import tornado.process
import tornado.simple_httpclient
import tornado.tcpserver
import tornado.tcpclient
import tornado.template
import tornado.testing
import tornado.util
import tornado.web
import tornado.websocket
import tornado.wsgi
try:
import pycurl
except ImportError:
pass
else:
import tornado.curl_httpclient
"""
class ImportTest(unittest.TestCase):
def test_import_everything(self):
# Test that all Tornado modules can be imported without side effects,
# specifically without initializing the default asyncio event loop.
# Since we can't tell which modules may have already beein imported
# in our process, do it in a subprocess for a clean slate.
proc = subprocess.Popen([sys.executable], stdin=subprocess.PIPE)
proc.communicate(_import_everything)
self.assertEqual(proc.returncode, 0)
def test_import_aliases(self):
# Ensure we don't delete formerly-documented aliases accidentally.
import tornado.ioloop
import tornado.gen
import tornado.util
self.assertIs(tornado.ioloop.TimeoutError, tornado.util.TimeoutError)
self.assertIs(tornado.gen.TimeoutError, tornado.util.TimeoutError)

View File

@@ -0,0 +1,719 @@
from concurrent.futures import ThreadPoolExecutor
from concurrent import futures
import contextlib
import datetime
import functools
import socket
import subprocess
import sys
import threading
import time
import types
from unittest import mock
import unittest
from tornado.escape import native_str
from tornado import gen
from tornado.ioloop import IOLoop, TimeoutError, PeriodicCallback
from tornado.log import app_log
from tornado.testing import AsyncTestCase, bind_unused_port, ExpectLog, gen_test
from tornado.test.util import skipIfNonUnix, skipOnTravis
import typing
if typing.TYPE_CHECKING:
from typing import List # noqa: F401
class TestIOLoop(AsyncTestCase):
def test_add_callback_return_sequence(self):
# A callback returning {} or [] shouldn't spin the CPU, see Issue #1803.
self.calls = 0
loop = self.io_loop
test = self
old_add_callback = loop.add_callback
def add_callback(self, callback, *args, **kwargs):
test.calls += 1
old_add_callback(callback, *args, **kwargs)
loop.add_callback = types.MethodType(add_callback, loop)
loop.add_callback(lambda: {})
loop.add_callback(lambda: [])
loop.add_timeout(datetime.timedelta(milliseconds=50), loop.stop)
loop.start()
self.assertLess(self.calls, 10)
@skipOnTravis
def test_add_callback_wakeup(self):
# Make sure that add_callback from inside a running IOLoop
# wakes up the IOLoop immediately instead of waiting for a timeout.
def callback():
self.called = True
self.stop()
def schedule_callback():
self.called = False
self.io_loop.add_callback(callback)
# Store away the time so we can check if we woke up immediately
self.start_time = time.time()
self.io_loop.add_timeout(self.io_loop.time(), schedule_callback)
self.wait()
self.assertAlmostEqual(time.time(), self.start_time, places=2)
self.assertTrue(self.called)
@skipOnTravis
def test_add_callback_wakeup_other_thread(self):
def target():
# sleep a bit to let the ioloop go into its poll loop
time.sleep(0.01)
self.stop_time = time.time()
self.io_loop.add_callback(self.stop)
thread = threading.Thread(target=target)
self.io_loop.add_callback(thread.start)
self.wait()
delta = time.time() - self.stop_time
self.assertLess(delta, 0.1)
thread.join()
def test_add_timeout_timedelta(self):
self.io_loop.add_timeout(datetime.timedelta(microseconds=1), self.stop)
self.wait()
def test_multiple_add(self):
sock, port = bind_unused_port()
try:
self.io_loop.add_handler(
sock.fileno(), lambda fd, events: None, IOLoop.READ
)
# Attempting to add the same handler twice fails
# (with a platform-dependent exception)
self.assertRaises(
Exception,
self.io_loop.add_handler,
sock.fileno(),
lambda fd, events: None,
IOLoop.READ,
)
finally:
self.io_loop.remove_handler(sock.fileno())
sock.close()
def test_remove_without_add(self):
# remove_handler should not throw an exception if called on an fd
# was never added.
sock, port = bind_unused_port()
try:
self.io_loop.remove_handler(sock.fileno())
finally:
sock.close()
def test_add_callback_from_signal(self):
# cheat a little bit and just run this normally, since we can't
# easily simulate the races that happen with real signal handlers
self.io_loop.add_callback_from_signal(self.stop)
self.wait()
def test_add_callback_from_signal_other_thread(self):
# Very crude test, just to make sure that we cover this case.
# This also happens to be the first test where we run an IOLoop in
# a non-main thread.
other_ioloop = IOLoop()
thread = threading.Thread(target=other_ioloop.start)
thread.start()
other_ioloop.add_callback_from_signal(other_ioloop.stop)
thread.join()
other_ioloop.close()
def test_add_callback_while_closing(self):
# add_callback should not fail if it races with another thread
# closing the IOLoop. The callbacks are dropped silently
# without executing.
closing = threading.Event()
def target():
other_ioloop.add_callback(other_ioloop.stop)
other_ioloop.start()
closing.set()
other_ioloop.close(all_fds=True)
other_ioloop = IOLoop()
thread = threading.Thread(target=target)
thread.start()
closing.wait()
for i in range(1000):
other_ioloop.add_callback(lambda: None)
@skipIfNonUnix # just because socketpair is so convenient
def test_read_while_writeable(self):
# Ensure that write events don't come in while we're waiting for
# a read and haven't asked for writeability. (the reverse is
# difficult to test for)
client, server = socket.socketpair()
try:
def handler(fd, events):
self.assertEqual(events, IOLoop.READ)
self.stop()
self.io_loop.add_handler(client.fileno(), handler, IOLoop.READ)
self.io_loop.add_timeout(
self.io_loop.time() + 0.01, functools.partial(server.send, b"asdf")
)
self.wait()
self.io_loop.remove_handler(client.fileno())
finally:
client.close()
server.close()
def test_remove_timeout_after_fire(self):
# It is not an error to call remove_timeout after it has run.
handle = self.io_loop.add_timeout(self.io_loop.time(), self.stop)
self.wait()
self.io_loop.remove_timeout(handle)
def test_remove_timeout_cleanup(self):
# Add and remove enough callbacks to trigger cleanup.
# Not a very thorough test, but it ensures that the cleanup code
# gets executed and doesn't blow up. This test is only really useful
# on PollIOLoop subclasses, but it should run silently on any
# implementation.
for i in range(2000):
timeout = self.io_loop.add_timeout(self.io_loop.time() + 3600, lambda: None)
self.io_loop.remove_timeout(timeout)
# HACK: wait two IOLoop iterations for the GC to happen.
self.io_loop.add_callback(lambda: self.io_loop.add_callback(self.stop))
self.wait()
def test_remove_timeout_from_timeout(self):
calls = [False, False]
# Schedule several callbacks and wait for them all to come due at once.
# t2 should be cancelled by t1, even though it is already scheduled to
# be run before the ioloop even looks at it.
now = self.io_loop.time()
def t1():
calls[0] = True
self.io_loop.remove_timeout(t2_handle)
self.io_loop.add_timeout(now + 0.01, t1)
def t2():
calls[1] = True
t2_handle = self.io_loop.add_timeout(now + 0.02, t2)
self.io_loop.add_timeout(now + 0.03, self.stop)
time.sleep(0.03)
self.wait()
self.assertEqual(calls, [True, False])
def test_timeout_with_arguments(self):
# This tests that all the timeout methods pass through *args correctly.
results = [] # type: List[int]
self.io_loop.add_timeout(self.io_loop.time(), results.append, 1)
self.io_loop.add_timeout(datetime.timedelta(seconds=0), results.append, 2)
self.io_loop.call_at(self.io_loop.time(), results.append, 3)
self.io_loop.call_later(0, results.append, 4)
self.io_loop.call_later(0, self.stop)
self.wait()
# The asyncio event loop does not guarantee the order of these
# callbacks.
self.assertEqual(sorted(results), [1, 2, 3, 4])
def test_add_timeout_return(self):
# All the timeout methods return non-None handles that can be
# passed to remove_timeout.
handle = self.io_loop.add_timeout(self.io_loop.time(), lambda: None)
self.assertFalse(handle is None)
self.io_loop.remove_timeout(handle)
def test_call_at_return(self):
handle = self.io_loop.call_at(self.io_loop.time(), lambda: None)
self.assertFalse(handle is None)
self.io_loop.remove_timeout(handle)
def test_call_later_return(self):
handle = self.io_loop.call_later(0, lambda: None)
self.assertFalse(handle is None)
self.io_loop.remove_timeout(handle)
def test_close_file_object(self):
"""When a file object is used instead of a numeric file descriptor,
the object should be closed (by IOLoop.close(all_fds=True),
not just the fd.
"""
# Use a socket since they are supported by IOLoop on all platforms.
# Unfortunately, sockets don't support the .closed attribute for
# inspecting their close status, so we must use a wrapper.
class SocketWrapper(object):
def __init__(self, sockobj):
self.sockobj = sockobj
self.closed = False
def fileno(self):
return self.sockobj.fileno()
def close(self):
self.closed = True
self.sockobj.close()
sockobj, port = bind_unused_port()
socket_wrapper = SocketWrapper(sockobj)
io_loop = IOLoop()
io_loop.add_handler(socket_wrapper, lambda fd, events: None, IOLoop.READ)
io_loop.close(all_fds=True)
self.assertTrue(socket_wrapper.closed)
def test_handler_callback_file_object(self):
"""The handler callback receives the same fd object it passed in."""
server_sock, port = bind_unused_port()
fds = []
def handle_connection(fd, events):
fds.append(fd)
conn, addr = server_sock.accept()
conn.close()
self.stop()
self.io_loop.add_handler(server_sock, handle_connection, IOLoop.READ)
with contextlib.closing(socket.socket()) as client_sock:
client_sock.connect(("127.0.0.1", port))
self.wait()
self.io_loop.remove_handler(server_sock)
self.io_loop.add_handler(server_sock.fileno(), handle_connection, IOLoop.READ)
with contextlib.closing(socket.socket()) as client_sock:
client_sock.connect(("127.0.0.1", port))
self.wait()
self.assertIs(fds[0], server_sock)
self.assertEqual(fds[1], server_sock.fileno())
self.io_loop.remove_handler(server_sock.fileno())
server_sock.close()
def test_mixed_fd_fileobj(self):
server_sock, port = bind_unused_port()
def f(fd, events):
pass
self.io_loop.add_handler(server_sock, f, IOLoop.READ)
with self.assertRaises(Exception):
# The exact error is unspecified - some implementations use
# IOError, others use ValueError.
self.io_loop.add_handler(server_sock.fileno(), f, IOLoop.READ)
self.io_loop.remove_handler(server_sock.fileno())
server_sock.close()
def test_reentrant(self):
"""Calling start() twice should raise an error, not deadlock."""
returned_from_start = [False]
got_exception = [False]
def callback():
try:
self.io_loop.start()
returned_from_start[0] = True
except Exception:
got_exception[0] = True
self.stop()
self.io_loop.add_callback(callback)
self.wait()
self.assertTrue(got_exception[0])
self.assertFalse(returned_from_start[0])
def test_exception_logging(self):
"""Uncaught exceptions get logged by the IOLoop."""
self.io_loop.add_callback(lambda: 1 / 0)
self.io_loop.add_callback(self.stop)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
def test_exception_logging_future(self):
"""The IOLoop examines exceptions from Futures and logs them."""
@gen.coroutine
def callback():
self.io_loop.add_callback(self.stop)
1 / 0
self.io_loop.add_callback(callback)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
def test_exception_logging_native_coro(self):
"""The IOLoop examines exceptions from awaitables and logs them."""
async def callback():
# Stop the IOLoop two iterations after raising an exception
# to give the exception time to be logged.
self.io_loop.add_callback(self.io_loop.add_callback, self.stop)
1 / 0
self.io_loop.add_callback(callback)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
def test_spawn_callback(self):
# Both add_callback and spawn_callback run directly on the IOLoop,
# so their errors are logged without stopping the test.
self.io_loop.add_callback(lambda: 1 / 0)
self.io_loop.add_callback(self.stop)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
# A spawned callback is run directly on the IOLoop, so it will be
# logged without stopping the test.
self.io_loop.spawn_callback(lambda: 1 / 0)
self.io_loop.add_callback(self.stop)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
@skipIfNonUnix
def test_remove_handler_from_handler(self):
# Create two sockets with simultaneous read events.
client, server = socket.socketpair()
try:
client.send(b"abc")
server.send(b"abc")
# After reading from one fd, remove the other from the IOLoop.
chunks = []
def handle_read(fd, events):
chunks.append(fd.recv(1024))
if fd is client:
self.io_loop.remove_handler(server)
else:
self.io_loop.remove_handler(client)
self.io_loop.add_handler(client, handle_read, self.io_loop.READ)
self.io_loop.add_handler(server, handle_read, self.io_loop.READ)
self.io_loop.call_later(0.1, self.stop)
self.wait()
# Only one fd was read; the other was cleanly removed.
self.assertEqual(chunks, [b"abc"])
finally:
client.close()
server.close()
@gen_test
def test_init_close_race(self):
# Regression test for #2367
def f():
for i in range(10):
loop = IOLoop()
loop.close()
yield gen.multi([self.io_loop.run_in_executor(None, f) for i in range(2)])
# Deliberately not a subclass of AsyncTestCase so the IOLoop isn't
# automatically set as current.
class TestIOLoopCurrent(unittest.TestCase):
def setUp(self):
self.io_loop = None
IOLoop.clear_current()
def tearDown(self):
if self.io_loop is not None:
self.io_loop.close()
def test_default_current(self):
self.io_loop = IOLoop()
# The first IOLoop with default arguments is made current.
self.assertIs(self.io_loop, IOLoop.current())
# A second IOLoop can be created but is not made current.
io_loop2 = IOLoop()
self.assertIs(self.io_loop, IOLoop.current())
io_loop2.close()
def test_non_current(self):
self.io_loop = IOLoop(make_current=False)
# The new IOLoop is not initially made current.
self.assertIsNone(IOLoop.current(instance=False))
# Starting the IOLoop makes it current, and stopping the loop
# makes it non-current. This process is repeatable.
for i in range(3):
def f():
self.current_io_loop = IOLoop.current()
self.io_loop.stop()
self.io_loop.add_callback(f)
self.io_loop.start()
self.assertIs(self.current_io_loop, self.io_loop)
# Now that the loop is stopped, it is no longer current.
self.assertIsNone(IOLoop.current(instance=False))
def test_force_current(self):
self.io_loop = IOLoop(make_current=True)
self.assertIs(self.io_loop, IOLoop.current())
with self.assertRaises(RuntimeError):
# A second make_current=True construction cannot succeed.
IOLoop(make_current=True)
# current() was not affected by the failed construction.
self.assertIs(self.io_loop, IOLoop.current())
class TestIOLoopCurrentAsync(AsyncTestCase):
@gen_test
def test_clear_without_current(self):
# If there is no current IOLoop, clear_current is a no-op (but
# should not fail). Use a thread so we see the threading.Local
# in a pristine state.
with ThreadPoolExecutor(1) as e:
yield e.submit(IOLoop.clear_current)
class TestIOLoopFutures(AsyncTestCase):
def test_add_future_threads(self):
with futures.ThreadPoolExecutor(1) as pool:
def dummy():
pass
self.io_loop.add_future(
pool.submit(dummy), lambda future: self.stop(future)
)
future = self.wait()
self.assertTrue(future.done())
self.assertTrue(future.result() is None)
@gen_test
def test_run_in_executor_gen(self):
event1 = threading.Event()
event2 = threading.Event()
def sync_func(self_event, other_event):
self_event.set()
other_event.wait()
# Note that return value doesn't actually do anything,
# it is just passed through to our final assertion to
# make sure it is passed through properly.
return self_event
# Run two synchronous functions, which would deadlock if not
# run in parallel.
res = yield [
IOLoop.current().run_in_executor(None, sync_func, event1, event2),
IOLoop.current().run_in_executor(None, sync_func, event2, event1),
]
self.assertEqual([event1, event2], res)
@gen_test
def test_run_in_executor_native(self):
event1 = threading.Event()
event2 = threading.Event()
def sync_func(self_event, other_event):
self_event.set()
other_event.wait()
return self_event
# Go through an async wrapper to ensure that the result of
# run_in_executor works with await and not just gen.coroutine
# (simply passing the underlying concurrrent future would do that).
async def async_wrapper(self_event, other_event):
return await IOLoop.current().run_in_executor(
None, sync_func, self_event, other_event
)
res = yield [async_wrapper(event1, event2), async_wrapper(event2, event1)]
self.assertEqual([event1, event2], res)
@gen_test
def test_set_default_executor(self):
count = [0]
class MyExecutor(futures.ThreadPoolExecutor):
def submit(self, func, *args):
count[0] += 1
return super(MyExecutor, self).submit(func, *args)
event = threading.Event()
def sync_func():
event.set()
executor = MyExecutor(1)
loop = IOLoop.current()
loop.set_default_executor(executor)
yield loop.run_in_executor(None, sync_func)
self.assertEqual(1, count[0])
self.assertTrue(event.is_set())
class TestIOLoopRunSync(unittest.TestCase):
def setUp(self):
self.io_loop = IOLoop()
def tearDown(self):
self.io_loop.close()
def test_sync_result(self):
with self.assertRaises(gen.BadYieldError):
self.io_loop.run_sync(lambda: 42)
def test_sync_exception(self):
with self.assertRaises(ZeroDivisionError):
self.io_loop.run_sync(lambda: 1 / 0)
def test_async_result(self):
@gen.coroutine
def f():
yield gen.moment
raise gen.Return(42)
self.assertEqual(self.io_loop.run_sync(f), 42)
def test_async_exception(self):
@gen.coroutine
def f():
yield gen.moment
1 / 0
with self.assertRaises(ZeroDivisionError):
self.io_loop.run_sync(f)
def test_current(self):
def f():
self.assertIs(IOLoop.current(), self.io_loop)
self.io_loop.run_sync(f)
def test_timeout(self):
@gen.coroutine
def f():
yield gen.sleep(1)
self.assertRaises(TimeoutError, self.io_loop.run_sync, f, timeout=0.01)
def test_native_coroutine(self):
@gen.coroutine
def f1():
yield gen.moment
async def f2():
await f1()
self.io_loop.run_sync(f2)
class TestPeriodicCallbackMath(unittest.TestCase):
def simulate_calls(self, pc, durations):
"""Simulate a series of calls to the PeriodicCallback.
Pass a list of call durations in seconds (negative values
work to simulate clock adjustments during the call, or more or
less equivalently, between calls). This method returns the
times at which each call would be made.
"""
calls = []
now = 1000
pc._next_timeout = now
for d in durations:
pc._update_next(now)
calls.append(pc._next_timeout)
now = pc._next_timeout + d
return calls
def dummy(self):
pass
def test_basic(self):
pc = PeriodicCallback(self.dummy, 10000)
self.assertEqual(
self.simulate_calls(pc, [0] * 5), [1010, 1020, 1030, 1040, 1050]
)
def test_overrun(self):
# If a call runs for too long, we skip entire cycles to get
# back on schedule.
call_durations = [9, 9, 10, 11, 20, 20, 35, 35, 0, 0, 0]
expected = [
1010,
1020,
1030, # first 3 calls on schedule
1050,
1070, # next 2 delayed one cycle
1100,
1130, # next 2 delayed 2 cycles
1170,
1210, # next 2 delayed 3 cycles
1220,
1230, # then back on schedule.
]
pc = PeriodicCallback(self.dummy, 10000)
self.assertEqual(self.simulate_calls(pc, call_durations), expected)
def test_clock_backwards(self):
pc = PeriodicCallback(self.dummy, 10000)
# Backwards jumps are ignored, potentially resulting in a
# slightly slow schedule (although we assume that when
# time.time() and time.monotonic() are different, time.time()
# is getting adjusted by NTP and is therefore more accurate)
self.assertEqual(
self.simulate_calls(pc, [-2, -1, -3, -2, 0]), [1010, 1020, 1030, 1040, 1050]
)
# For big jumps, we should perhaps alter the schedule, but we
# don't currently. This trace shows that we run callbacks
# every 10s of time.time(), but the first and second calls are
# 110s of real time apart because the backwards jump is
# ignored.
self.assertEqual(self.simulate_calls(pc, [-100, 0, 0]), [1010, 1020, 1030])
def test_jitter(self):
random_times = [0.5, 1, 0, 0.75]
expected = [1010, 1022.5, 1030, 1041.25]
call_durations = [0] * len(random_times)
pc = PeriodicCallback(self.dummy, 10000, jitter=0.5)
def mock_random():
return random_times.pop(0)
with mock.patch("random.random", mock_random):
self.assertEqual(self.simulate_calls(pc, call_durations), expected)
class TestIOLoopConfiguration(unittest.TestCase):
def run_python(self, *statements):
stmt_list = [
"from tornado.ioloop import IOLoop",
"classname = lambda x: x.__class__.__name__",
] + list(statements)
args = [sys.executable, "-c", "; ".join(stmt_list)]
return native_str(subprocess.check_output(args)).strip()
def test_default(self):
# When asyncio is available, it is used by default.
cls = self.run_python("print(classname(IOLoop.current()))")
self.assertEqual(cls, "AsyncIOMainLoop")
cls = self.run_python("print(classname(IOLoop()))")
self.assertEqual(cls, "AsyncIOLoop")
def test_asyncio(self):
cls = self.run_python(
'IOLoop.configure("tornado.platform.asyncio.AsyncIOLoop")',
"print(classname(IOLoop.current()))",
)
self.assertEqual(cls, "AsyncIOMainLoop")
def test_asyncio_main(self):
cls = self.run_python(
"from tornado.platform.asyncio import AsyncIOMainLoop",
"AsyncIOMainLoop().install()",
"print(classname(IOLoop.current()))",
)
self.assertEqual(cls, "AsyncIOMainLoop")
if __name__ == "__main__":
unittest.main()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,151 @@
import datetime
import os
import shutil
import tempfile
import unittest
import tornado.locale
from tornado.escape import utf8, to_unicode
from tornado.util import unicode_type
class TranslationLoaderTest(unittest.TestCase):
# TODO: less hacky way to get isolated tests
SAVE_VARS = ["_translations", "_supported_locales", "_use_gettext"]
def clear_locale_cache(self):
tornado.locale.Locale._cache = {}
def setUp(self):
self.saved = {} # type: dict
for var in TranslationLoaderTest.SAVE_VARS:
self.saved[var] = getattr(tornado.locale, var)
self.clear_locale_cache()
def tearDown(self):
for k, v in self.saved.items():
setattr(tornado.locale, k, v)
self.clear_locale_cache()
def test_csv(self):
tornado.locale.load_translations(
os.path.join(os.path.dirname(__file__), "csv_translations")
)
locale = tornado.locale.get("fr_FR")
self.assertTrue(isinstance(locale, tornado.locale.CSVLocale))
self.assertEqual(locale.translate("school"), u"\u00e9cole")
def test_csv_bom(self):
with open(
os.path.join(os.path.dirname(__file__), "csv_translations", "fr_FR.csv"),
"rb",
) as f:
char_data = to_unicode(f.read())
# Re-encode our input data (which is utf-8 without BOM) in
# encodings that use the BOM and ensure that we can still load
# it. Note that utf-16-le and utf-16-be do not write a BOM,
# so we only test whichver variant is native to our platform.
for encoding in ["utf-8-sig", "utf-16"]:
tmpdir = tempfile.mkdtemp()
try:
with open(os.path.join(tmpdir, "fr_FR.csv"), "wb") as f:
f.write(char_data.encode(encoding))
tornado.locale.load_translations(tmpdir)
locale = tornado.locale.get("fr_FR")
self.assertIsInstance(locale, tornado.locale.CSVLocale)
self.assertEqual(locale.translate("school"), u"\u00e9cole")
finally:
shutil.rmtree(tmpdir)
def test_gettext(self):
tornado.locale.load_gettext_translations(
os.path.join(os.path.dirname(__file__), "gettext_translations"),
"tornado_test",
)
locale = tornado.locale.get("fr_FR")
self.assertTrue(isinstance(locale, tornado.locale.GettextLocale))
self.assertEqual(locale.translate("school"), u"\u00e9cole")
self.assertEqual(locale.pgettext("law", "right"), u"le droit")
self.assertEqual(locale.pgettext("good", "right"), u"le bien")
self.assertEqual(
locale.pgettext("organization", "club", "clubs", 1), u"le club"
)
self.assertEqual(
locale.pgettext("organization", "club", "clubs", 2), u"les clubs"
)
self.assertEqual(locale.pgettext("stick", "club", "clubs", 1), u"le b\xe2ton")
self.assertEqual(locale.pgettext("stick", "club", "clubs", 2), u"les b\xe2tons")
class LocaleDataTest(unittest.TestCase):
def test_non_ascii_name(self):
name = tornado.locale.LOCALE_NAMES["es_LA"]["name"]
self.assertTrue(isinstance(name, unicode_type))
self.assertEqual(name, u"Espa\u00f1ol")
self.assertEqual(utf8(name), b"Espa\xc3\xb1ol")
class EnglishTest(unittest.TestCase):
def test_format_date(self):
locale = tornado.locale.get("en_US")
date = datetime.datetime(2013, 4, 28, 18, 35)
self.assertEqual(
locale.format_date(date, full_format=True), "April 28, 2013 at 6:35 pm"
)
now = datetime.datetime.utcnow()
self.assertEqual(
locale.format_date(now - datetime.timedelta(seconds=2), full_format=False),
"2 seconds ago",
)
self.assertEqual(
locale.format_date(now - datetime.timedelta(minutes=2), full_format=False),
"2 minutes ago",
)
self.assertEqual(
locale.format_date(now - datetime.timedelta(hours=2), full_format=False),
"2 hours ago",
)
self.assertEqual(
locale.format_date(
now - datetime.timedelta(days=1), full_format=False, shorter=True
),
"yesterday",
)
date = now - datetime.timedelta(days=2)
self.assertEqual(
locale.format_date(date, full_format=False, shorter=True),
locale._weekdays[date.weekday()],
)
date = now - datetime.timedelta(days=300)
self.assertEqual(
locale.format_date(date, full_format=False, shorter=True),
"%s %d" % (locale._months[date.month - 1], date.day),
)
date = now - datetime.timedelta(days=500)
self.assertEqual(
locale.format_date(date, full_format=False, shorter=True),
"%s %d, %d" % (locale._months[date.month - 1], date.day, date.year),
)
def test_friendly_number(self):
locale = tornado.locale.get("en_US")
self.assertEqual(locale.friendly_number(1000000), "1,000,000")
def test_list(self):
locale = tornado.locale.get("en_US")
self.assertEqual(locale.list([]), "")
self.assertEqual(locale.list(["A"]), "A")
self.assertEqual(locale.list(["A", "B"]), "A and B")
self.assertEqual(locale.list(["A", "B", "C"]), "A, B and C")
def test_format_day(self):
locale = tornado.locale.get("en_US")
date = datetime.datetime(2013, 4, 28, 18, 35)
self.assertEqual(locale.format_day(date=date, dow=True), "Sunday, April 28")
self.assertEqual(locale.format_day(date=date, dow=False), "April 28")

View File

@@ -0,0 +1,535 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import asyncio
from datetime import timedelta
import typing # noqa: F401
import unittest
from tornado import gen, locks
from tornado.gen import TimeoutError
from tornado.testing import gen_test, AsyncTestCase
class ConditionTest(AsyncTestCase):
def setUp(self):
super(ConditionTest, self).setUp()
self.history = [] # type: typing.List[typing.Union[int, str]]
def record_done(self, future, key):
"""Record the resolution of a Future returned by Condition.wait."""
def callback(_):
if not future.result():
# wait() resolved to False, meaning it timed out.
self.history.append("timeout")
else:
self.history.append(key)
future.add_done_callback(callback)
def loop_briefly(self):
"""Run all queued callbacks on the IOLoop.
In these tests, this method is used after calling notify() to
preserve the pre-5.0 behavior in which callbacks ran
synchronously.
"""
self.io_loop.add_callback(self.stop)
self.wait()
def test_repr(self):
c = locks.Condition()
self.assertIn("Condition", repr(c))
self.assertNotIn("waiters", repr(c))
c.wait()
self.assertIn("waiters", repr(c))
@gen_test
def test_notify(self):
c = locks.Condition()
self.io_loop.call_later(0.01, c.notify)
yield c.wait()
def test_notify_1(self):
c = locks.Condition()
self.record_done(c.wait(), "wait1")
self.record_done(c.wait(), "wait2")
c.notify(1)
self.loop_briefly()
self.history.append("notify1")
c.notify(1)
self.loop_briefly()
self.history.append("notify2")
self.assertEqual(["wait1", "notify1", "wait2", "notify2"], self.history)
def test_notify_n(self):
c = locks.Condition()
for i in range(6):
self.record_done(c.wait(), i)
c.notify(3)
self.loop_briefly()
# Callbacks execute in the order they were registered.
self.assertEqual(list(range(3)), self.history)
c.notify(1)
self.loop_briefly()
self.assertEqual(list(range(4)), self.history)
c.notify(2)
self.loop_briefly()
self.assertEqual(list(range(6)), self.history)
def test_notify_all(self):
c = locks.Condition()
for i in range(4):
self.record_done(c.wait(), i)
c.notify_all()
self.loop_briefly()
self.history.append("notify_all")
# Callbacks execute in the order they were registered.
self.assertEqual(list(range(4)) + ["notify_all"], self.history) # type: ignore
@gen_test
def test_wait_timeout(self):
c = locks.Condition()
wait = c.wait(timedelta(seconds=0.01))
self.io_loop.call_later(0.02, c.notify) # Too late.
yield gen.sleep(0.03)
self.assertFalse((yield wait))
@gen_test
def test_wait_timeout_preempted(self):
c = locks.Condition()
# This fires before the wait times out.
self.io_loop.call_later(0.01, c.notify)
wait = c.wait(timedelta(seconds=0.02))
yield gen.sleep(0.03)
yield wait # No TimeoutError.
@gen_test
def test_notify_n_with_timeout(self):
# Register callbacks 0, 1, 2, and 3. Callback 1 has a timeout.
# Wait for that timeout to expire, then do notify(2) and make
# sure everyone runs. Verifies that a timed-out callback does
# not count against the 'n' argument to notify().
c = locks.Condition()
self.record_done(c.wait(), 0)
self.record_done(c.wait(timedelta(seconds=0.01)), 1)
self.record_done(c.wait(), 2)
self.record_done(c.wait(), 3)
# Wait for callback 1 to time out.
yield gen.sleep(0.02)
self.assertEqual(["timeout"], self.history)
c.notify(2)
yield gen.sleep(0.01)
self.assertEqual(["timeout", 0, 2], self.history)
self.assertEqual(["timeout", 0, 2], self.history)
c.notify()
yield
self.assertEqual(["timeout", 0, 2, 3], self.history)
@gen_test
def test_notify_all_with_timeout(self):
c = locks.Condition()
self.record_done(c.wait(), 0)
self.record_done(c.wait(timedelta(seconds=0.01)), 1)
self.record_done(c.wait(), 2)
# Wait for callback 1 to time out.
yield gen.sleep(0.02)
self.assertEqual(["timeout"], self.history)
c.notify_all()
yield
self.assertEqual(["timeout", 0, 2], self.history)
@gen_test
def test_nested_notify(self):
# Ensure no notifications lost, even if notify() is reentered by a
# waiter calling notify().
c = locks.Condition()
# Three waiters.
futures = [asyncio.ensure_future(c.wait()) for _ in range(3)]
# First and second futures resolved. Second future reenters notify(),
# resolving third future.
futures[1].add_done_callback(lambda _: c.notify())
c.notify(2)
yield
self.assertTrue(all(f.done() for f in futures))
@gen_test
def test_garbage_collection(self):
# Test that timed-out waiters are occasionally cleaned from the queue.
c = locks.Condition()
for _ in range(101):
c.wait(timedelta(seconds=0.01))
future = asyncio.ensure_future(c.wait())
self.assertEqual(102, len(c._waiters))
# Let first 101 waiters time out, triggering a collection.
yield gen.sleep(0.02)
self.assertEqual(1, len(c._waiters))
# Final waiter is still active.
self.assertFalse(future.done())
c.notify()
self.assertTrue(future.done())
class EventTest(AsyncTestCase):
def test_repr(self):
event = locks.Event()
self.assertTrue("clear" in str(event))
self.assertFalse("set" in str(event))
event.set()
self.assertFalse("clear" in str(event))
self.assertTrue("set" in str(event))
def test_event(self):
e = locks.Event()
future_0 = asyncio.ensure_future(e.wait())
e.set()
future_1 = asyncio.ensure_future(e.wait())
e.clear()
future_2 = asyncio.ensure_future(e.wait())
self.assertTrue(future_0.done())
self.assertTrue(future_1.done())
self.assertFalse(future_2.done())
@gen_test
def test_event_timeout(self):
e = locks.Event()
with self.assertRaises(TimeoutError):
yield e.wait(timedelta(seconds=0.01))
# After a timed-out waiter, normal operation works.
self.io_loop.add_timeout(timedelta(seconds=0.01), e.set)
yield e.wait(timedelta(seconds=1))
def test_event_set_multiple(self):
e = locks.Event()
e.set()
e.set()
self.assertTrue(e.is_set())
def test_event_wait_clear(self):
e = locks.Event()
f0 = asyncio.ensure_future(e.wait())
e.clear()
f1 = asyncio.ensure_future(e.wait())
e.set()
self.assertTrue(f0.done())
self.assertTrue(f1.done())
class SemaphoreTest(AsyncTestCase):
def test_negative_value(self):
self.assertRaises(ValueError, locks.Semaphore, value=-1)
def test_repr(self):
sem = locks.Semaphore()
self.assertIn("Semaphore", repr(sem))
self.assertIn("unlocked,value:1", repr(sem))
sem.acquire()
self.assertIn("locked", repr(sem))
self.assertNotIn("waiters", repr(sem))
sem.acquire()
self.assertIn("waiters", repr(sem))
def test_acquire(self):
sem = locks.Semaphore()
f0 = asyncio.ensure_future(sem.acquire())
self.assertTrue(f0.done())
# Wait for release().
f1 = asyncio.ensure_future(sem.acquire())
self.assertFalse(f1.done())
f2 = asyncio.ensure_future(sem.acquire())
sem.release()
self.assertTrue(f1.done())
self.assertFalse(f2.done())
sem.release()
self.assertTrue(f2.done())
sem.release()
# Now acquire() is instant.
self.assertTrue(asyncio.ensure_future(sem.acquire()).done())
self.assertEqual(0, len(sem._waiters))
@gen_test
def test_acquire_timeout(self):
sem = locks.Semaphore(2)
yield sem.acquire()
yield sem.acquire()
acquire = sem.acquire(timedelta(seconds=0.01))
self.io_loop.call_later(0.02, sem.release) # Too late.
yield gen.sleep(0.3)
with self.assertRaises(gen.TimeoutError):
yield acquire
sem.acquire()
f = asyncio.ensure_future(sem.acquire())
self.assertFalse(f.done())
sem.release()
self.assertTrue(f.done())
@gen_test
def test_acquire_timeout_preempted(self):
sem = locks.Semaphore(1)
yield sem.acquire()
# This fires before the wait times out.
self.io_loop.call_later(0.01, sem.release)
acquire = sem.acquire(timedelta(seconds=0.02))
yield gen.sleep(0.03)
yield acquire # No TimeoutError.
def test_release_unacquired(self):
# Unbounded releases are allowed, and increment the semaphore's value.
sem = locks.Semaphore()
sem.release()
sem.release()
# Now the counter is 3. We can acquire three times before blocking.
self.assertTrue(asyncio.ensure_future(sem.acquire()).done())
self.assertTrue(asyncio.ensure_future(sem.acquire()).done())
self.assertTrue(asyncio.ensure_future(sem.acquire()).done())
self.assertFalse(asyncio.ensure_future(sem.acquire()).done())
@gen_test
def test_garbage_collection(self):
# Test that timed-out waiters are occasionally cleaned from the queue.
sem = locks.Semaphore(value=0)
futures = [
asyncio.ensure_future(sem.acquire(timedelta(seconds=0.01)))
for _ in range(101)
]
future = asyncio.ensure_future(sem.acquire())
self.assertEqual(102, len(sem._waiters))
# Let first 101 waiters time out, triggering a collection.
yield gen.sleep(0.02)
self.assertEqual(1, len(sem._waiters))
# Final waiter is still active.
self.assertFalse(future.done())
sem.release()
self.assertTrue(future.done())
# Prevent "Future exception was never retrieved" messages.
for future in futures:
self.assertRaises(TimeoutError, future.result)
class SemaphoreContextManagerTest(AsyncTestCase):
@gen_test
def test_context_manager(self):
sem = locks.Semaphore()
with (yield sem.acquire()) as yielded:
self.assertTrue(yielded is None)
# Semaphore was released and can be acquired again.
self.assertTrue(asyncio.ensure_future(sem.acquire()).done())
@gen_test
def test_context_manager_async_await(self):
# Repeat the above test using 'async with'.
sem = locks.Semaphore()
async def f():
async with sem as yielded:
self.assertTrue(yielded is None)
yield f()
# Semaphore was released and can be acquired again.
self.assertTrue(asyncio.ensure_future(sem.acquire()).done())
@gen_test
def test_context_manager_exception(self):
sem = locks.Semaphore()
with self.assertRaises(ZeroDivisionError):
with (yield sem.acquire()):
1 / 0
# Semaphore was released and can be acquired again.
self.assertTrue(asyncio.ensure_future(sem.acquire()).done())
@gen_test
def test_context_manager_timeout(self):
sem = locks.Semaphore()
with (yield sem.acquire(timedelta(seconds=0.01))):
pass
# Semaphore was released and can be acquired again.
self.assertTrue(asyncio.ensure_future(sem.acquire()).done())
@gen_test
def test_context_manager_timeout_error(self):
sem = locks.Semaphore(value=0)
with self.assertRaises(gen.TimeoutError):
with (yield sem.acquire(timedelta(seconds=0.01))):
pass
# Counter is still 0.
self.assertFalse(asyncio.ensure_future(sem.acquire()).done())
@gen_test
def test_context_manager_contended(self):
sem = locks.Semaphore()
history = []
@gen.coroutine
def f(index):
with (yield sem.acquire()):
history.append("acquired %d" % index)
yield gen.sleep(0.01)
history.append("release %d" % index)
yield [f(i) for i in range(2)]
expected_history = []
for i in range(2):
expected_history.extend(["acquired %d" % i, "release %d" % i])
self.assertEqual(expected_history, history)
@gen_test
def test_yield_sem(self):
# Ensure we catch a "with (yield sem)", which should be
# "with (yield sem.acquire())".
with self.assertRaises(gen.BadYieldError):
with (yield locks.Semaphore()):
pass
def test_context_manager_misuse(self):
# Ensure we catch a "with sem", which should be
# "with (yield sem.acquire())".
with self.assertRaises(RuntimeError):
with locks.Semaphore():
pass
class BoundedSemaphoreTest(AsyncTestCase):
def test_release_unacquired(self):
sem = locks.BoundedSemaphore()
self.assertRaises(ValueError, sem.release)
# Value is 0.
sem.acquire()
# Block on acquire().
future = asyncio.ensure_future(sem.acquire())
self.assertFalse(future.done())
sem.release()
self.assertTrue(future.done())
# Value is 1.
sem.release()
self.assertRaises(ValueError, sem.release)
class LockTests(AsyncTestCase):
def test_repr(self):
lock = locks.Lock()
# No errors.
repr(lock)
lock.acquire()
repr(lock)
def test_acquire_release(self):
lock = locks.Lock()
self.assertTrue(asyncio.ensure_future(lock.acquire()).done())
future = asyncio.ensure_future(lock.acquire())
self.assertFalse(future.done())
lock.release()
self.assertTrue(future.done())
@gen_test
def test_acquire_fifo(self):
lock = locks.Lock()
self.assertTrue(asyncio.ensure_future(lock.acquire()).done())
N = 5
history = []
@gen.coroutine
def f(idx):
with (yield lock.acquire()):
history.append(idx)
futures = [f(i) for i in range(N)]
self.assertFalse(any(future.done() for future in futures))
lock.release()
yield futures
self.assertEqual(list(range(N)), history)
@gen_test
def test_acquire_fifo_async_with(self):
# Repeat the above test using `async with lock:`
# instead of `with (yield lock.acquire()):`.
lock = locks.Lock()
self.assertTrue(asyncio.ensure_future(lock.acquire()).done())
N = 5
history = []
async def f(idx):
async with lock:
history.append(idx)
futures = [f(i) for i in range(N)]
lock.release()
yield futures
self.assertEqual(list(range(N)), history)
@gen_test
def test_acquire_timeout(self):
lock = locks.Lock()
lock.acquire()
with self.assertRaises(gen.TimeoutError):
yield lock.acquire(timeout=timedelta(seconds=0.01))
# Still locked.
self.assertFalse(asyncio.ensure_future(lock.acquire()).done())
def test_multi_release(self):
lock = locks.Lock()
self.assertRaises(RuntimeError, lock.release)
lock.acquire()
lock.release()
self.assertRaises(RuntimeError, lock.release)
@gen_test
def test_yield_lock(self):
# Ensure we catch a "with (yield lock)", which should be
# "with (yield lock.acquire())".
with self.assertRaises(gen.BadYieldError):
with (yield locks.Lock()):
pass
def test_context_manager_misuse(self):
# Ensure we catch a "with lock", which should be
# "with (yield lock.acquire())".
with self.assertRaises(RuntimeError):
with locks.Lock():
pass
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,245 @@
#
# Copyright 2012 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import contextlib
import glob
import logging
import os
import re
import subprocess
import sys
import tempfile
import unittest
import warnings
from tornado.escape import utf8
from tornado.log import LogFormatter, define_logging_options, enable_pretty_logging
from tornado.options import OptionParser
from tornado.util import basestring_type
@contextlib.contextmanager
def ignore_bytes_warning():
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=BytesWarning)
yield
class LogFormatterTest(unittest.TestCase):
# Matches the output of a single logging call (which may be multiple lines
# if a traceback was included, so we use the DOTALL option)
LINE_RE = re.compile(
b"(?s)\x01\\[E [0-9]{6} [0-9]{2}:[0-9]{2}:[0-9]{2} log_test:[0-9]+\\]\x02 (.*)"
)
def setUp(self):
self.formatter = LogFormatter(color=False)
# Fake color support. We can't guarantee anything about the $TERM
# variable when the tests are run, so just patch in some values
# for testing. (testing with color off fails to expose some potential
# encoding issues from the control characters)
self.formatter._colors = {logging.ERROR: u"\u0001"}
self.formatter._normal = u"\u0002"
# construct a Logger directly to bypass getLogger's caching
self.logger = logging.Logger("LogFormatterTest")
self.logger.propagate = False
self.tempdir = tempfile.mkdtemp()
self.filename = os.path.join(self.tempdir, "log.out")
self.handler = self.make_handler(self.filename)
self.handler.setFormatter(self.formatter)
self.logger.addHandler(self.handler)
def tearDown(self):
self.handler.close()
os.unlink(self.filename)
os.rmdir(self.tempdir)
def make_handler(self, filename):
# Base case: default setup without explicit encoding.
# In python 2, supports arbitrary byte strings and unicode objects
# that contain only ascii. In python 3, supports ascii-only unicode
# strings (but byte strings will be repr'd automatically).
return logging.FileHandler(filename)
def get_output(self):
with open(self.filename, "rb") as f:
line = f.read().strip()
m = LogFormatterTest.LINE_RE.match(line)
if m:
return m.group(1)
else:
raise Exception("output didn't match regex: %r" % line)
def test_basic_logging(self):
self.logger.error("foo")
self.assertEqual(self.get_output(), b"foo")
def test_bytes_logging(self):
with ignore_bytes_warning():
# This will be "\xe9" on python 2 or "b'\xe9'" on python 3
self.logger.error(b"\xe9")
self.assertEqual(self.get_output(), utf8(repr(b"\xe9")))
def test_utf8_logging(self):
with ignore_bytes_warning():
self.logger.error(u"\u00e9".encode("utf8"))
if issubclass(bytes, basestring_type):
# on python 2, utf8 byte strings (and by extension ascii byte
# strings) are passed through as-is.
self.assertEqual(self.get_output(), utf8(u"\u00e9"))
else:
# on python 3, byte strings always get repr'd even if
# they're ascii-only, so this degenerates into another
# copy of test_bytes_logging.
self.assertEqual(self.get_output(), utf8(repr(utf8(u"\u00e9"))))
def test_bytes_exception_logging(self):
try:
raise Exception(b"\xe9")
except Exception:
self.logger.exception("caught exception")
# This will be "Exception: \xe9" on python 2 or
# "Exception: b'\xe9'" on python 3.
output = self.get_output()
self.assertRegexpMatches(output, br"Exception.*\\xe9")
# The traceback contains newlines, which should not have been escaped.
self.assertNotIn(br"\n", output)
class UnicodeLogFormatterTest(LogFormatterTest):
def make_handler(self, filename):
# Adding an explicit encoding configuration allows non-ascii unicode
# strings in both python 2 and 3, without changing the behavior
# for byte strings.
return logging.FileHandler(filename, encoding="utf8")
def test_unicode_logging(self):
self.logger.error(u"\u00e9")
self.assertEqual(self.get_output(), utf8(u"\u00e9"))
class EnablePrettyLoggingTest(unittest.TestCase):
def setUp(self):
super(EnablePrettyLoggingTest, self).setUp()
self.options = OptionParser()
define_logging_options(self.options)
self.logger = logging.Logger("tornado.test.log_test.EnablePrettyLoggingTest")
self.logger.propagate = False
def test_log_file(self):
tmpdir = tempfile.mkdtemp()
try:
self.options.log_file_prefix = tmpdir + "/test_log"
enable_pretty_logging(options=self.options, logger=self.logger)
self.assertEqual(1, len(self.logger.handlers))
self.logger.error("hello")
self.logger.handlers[0].flush()
filenames = glob.glob(tmpdir + "/test_log*")
self.assertEqual(1, len(filenames))
with open(filenames[0]) as f:
self.assertRegexpMatches(f.read(), r"^\[E [^]]*\] hello$")
finally:
for handler in self.logger.handlers:
handler.flush()
handler.close()
for filename in glob.glob(tmpdir + "/test_log*"):
os.unlink(filename)
os.rmdir(tmpdir)
def test_log_file_with_timed_rotating(self):
tmpdir = tempfile.mkdtemp()
try:
self.options.log_file_prefix = tmpdir + "/test_log"
self.options.log_rotate_mode = "time"
enable_pretty_logging(options=self.options, logger=self.logger)
self.logger.error("hello")
self.logger.handlers[0].flush()
filenames = glob.glob(tmpdir + "/test_log*")
self.assertEqual(1, len(filenames))
with open(filenames[0]) as f:
self.assertRegexpMatches(f.read(), r"^\[E [^]]*\] hello$")
finally:
for handler in self.logger.handlers:
handler.flush()
handler.close()
for filename in glob.glob(tmpdir + "/test_log*"):
os.unlink(filename)
os.rmdir(tmpdir)
def test_wrong_rotate_mode_value(self):
try:
self.options.log_file_prefix = "some_path"
self.options.log_rotate_mode = "wrong_mode"
self.assertRaises(
ValueError,
enable_pretty_logging,
options=self.options,
logger=self.logger,
)
finally:
for handler in self.logger.handlers:
handler.flush()
handler.close()
class LoggingOptionTest(unittest.TestCase):
"""Test the ability to enable and disable Tornado's logging hooks."""
def logs_present(self, statement, args=None):
# Each test may manipulate and/or parse the options and then logs
# a line at the 'info' level. This level is ignored in the
# logging module by default, but Tornado turns it on by default
# so it is the easiest way to tell whether tornado's logging hooks
# ran.
IMPORT = "from tornado.options import options, parse_command_line"
LOG_INFO = 'import logging; logging.info("hello")'
program = ";".join([IMPORT, statement, LOG_INFO])
proc = subprocess.Popen(
[sys.executable, "-c", program] + (args or []),
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
stdout, stderr = proc.communicate()
self.assertEqual(proc.returncode, 0, "process failed: %r" % stdout)
return b"hello" in stdout
def test_default(self):
self.assertFalse(self.logs_present("pass"))
def test_tornado_default(self):
self.assertTrue(self.logs_present("parse_command_line()"))
def test_disable_command_line(self):
self.assertFalse(self.logs_present("parse_command_line()", ["--logging=none"]))
def test_disable_command_line_case_insensitive(self):
self.assertFalse(self.logs_present("parse_command_line()", ["--logging=None"]))
def test_disable_code_string(self):
self.assertFalse(
self.logs_present('options.logging = "none"; parse_command_line()')
)
def test_disable_code_none(self):
self.assertFalse(
self.logs_present("options.logging = None; parse_command_line()")
)
def test_disable_override(self):
# command line trumps code defaults
self.assertTrue(
self.logs_present(
"options.logging = None; parse_command_line()", ["--logging=info"]
)
)

View File

@@ -0,0 +1,226 @@
import errno
import os
import signal
import socket
from subprocess import Popen
import sys
import time
import unittest
from tornado.netutil import (
BlockingResolver,
OverrideResolver,
ThreadedResolver,
is_valid_ip,
bind_sockets,
)
from tornado.testing import AsyncTestCase, gen_test, bind_unused_port
from tornado.test.util import skipIfNoNetwork
import typing
if typing.TYPE_CHECKING:
from typing import List # noqa: F401
try:
import pycares # type: ignore
except ImportError:
pycares = None
else:
from tornado.platform.caresresolver import CaresResolver
try:
import twisted # type: ignore
import twisted.names # type: ignore
except ImportError:
twisted = None
else:
from tornado.platform.twisted import TwistedResolver
class _ResolverTestMixin(object):
@gen_test
def test_localhost(self):
addrinfo = yield self.resolver.resolve("localhost", 80, socket.AF_UNSPEC)
self.assertIn((socket.AF_INET, ("127.0.0.1", 80)), addrinfo)
# It is impossible to quickly and consistently generate an error in name
# resolution, so test this case separately, using mocks as needed.
class _ResolverErrorTestMixin(object):
@gen_test
def test_bad_host(self):
with self.assertRaises(IOError):
yield self.resolver.resolve("an invalid domain", 80, socket.AF_UNSPEC)
def _failing_getaddrinfo(*args):
"""Dummy implementation of getaddrinfo for use in mocks"""
raise socket.gaierror(errno.EIO, "mock: lookup failed")
@skipIfNoNetwork
class BlockingResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
super(BlockingResolverTest, self).setUp()
self.resolver = BlockingResolver()
# getaddrinfo-based tests need mocking to reliably generate errors;
# some configurations are slow to produce errors and take longer than
# our default timeout.
class BlockingResolverErrorTest(AsyncTestCase, _ResolverErrorTestMixin):
def setUp(self):
super(BlockingResolverErrorTest, self).setUp()
self.resolver = BlockingResolver()
self.real_getaddrinfo = socket.getaddrinfo
socket.getaddrinfo = _failing_getaddrinfo
def tearDown(self):
socket.getaddrinfo = self.real_getaddrinfo
super(BlockingResolverErrorTest, self).tearDown()
class OverrideResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
super(OverrideResolverTest, self).setUp()
mapping = {
("google.com", 80): ("1.2.3.4", 80),
("google.com", 80, socket.AF_INET): ("1.2.3.4", 80),
("google.com", 80, socket.AF_INET6): (
"2a02:6b8:7c:40c:c51e:495f:e23a:3",
80,
),
}
self.resolver = OverrideResolver(BlockingResolver(), mapping)
@gen_test
def test_resolve_multiaddr(self):
result = yield self.resolver.resolve("google.com", 80, socket.AF_INET)
self.assertIn((socket.AF_INET, ("1.2.3.4", 80)), result)
result = yield self.resolver.resolve("google.com", 80, socket.AF_INET6)
self.assertIn(
(socket.AF_INET6, ("2a02:6b8:7c:40c:c51e:495f:e23a:3", 80, 0, 0)), result
)
@skipIfNoNetwork
class ThreadedResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
super(ThreadedResolverTest, self).setUp()
self.resolver = ThreadedResolver()
def tearDown(self):
self.resolver.close()
super(ThreadedResolverTest, self).tearDown()
class ThreadedResolverErrorTest(AsyncTestCase, _ResolverErrorTestMixin):
def setUp(self):
super(ThreadedResolverErrorTest, self).setUp()
self.resolver = BlockingResolver()
self.real_getaddrinfo = socket.getaddrinfo
socket.getaddrinfo = _failing_getaddrinfo
def tearDown(self):
socket.getaddrinfo = self.real_getaddrinfo
super(ThreadedResolverErrorTest, self).tearDown()
@skipIfNoNetwork
@unittest.skipIf(sys.platform == "win32", "preexec_fn not available on win32")
class ThreadedResolverImportTest(unittest.TestCase):
def test_import(self):
TIMEOUT = 5
# Test for a deadlock when importing a module that runs the
# ThreadedResolver at import-time. See resolve_test.py for
# full explanation.
command = [sys.executable, "-c", "import tornado.test.resolve_test_helper"]
start = time.time()
popen = Popen(command, preexec_fn=lambda: signal.alarm(TIMEOUT))
while time.time() - start < TIMEOUT:
return_code = popen.poll()
if return_code is not None:
self.assertEqual(0, return_code)
return # Success.
time.sleep(0.05)
self.fail("import timed out")
# We do not test errors with CaresResolver:
# Some DNS-hijacking ISPs (e.g. Time Warner) return non-empty results
# with an NXDOMAIN status code. Most resolvers treat this as an error;
# C-ares returns the results, making the "bad_host" tests unreliable.
# C-ares will try to resolve even malformed names, such as the
# name with spaces used in this test.
@skipIfNoNetwork
@unittest.skipIf(pycares is None, "pycares module not present")
class CaresResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
super(CaresResolverTest, self).setUp()
self.resolver = CaresResolver()
# TwistedResolver produces consistent errors in our test cases so we
# could test the regular and error cases in the same class. However,
# in the error cases it appears that cleanup of socket objects is
# handled asynchronously and occasionally results in "unclosed socket"
# warnings if not given time to shut down (and there is no way to
# explicitly shut it down). This makes the test flaky, so we do not
# test error cases here.
@skipIfNoNetwork
@unittest.skipIf(twisted is None, "twisted module not present")
@unittest.skipIf(
getattr(twisted, "__version__", "0.0") < "12.1", "old version of twisted"
)
class TwistedResolverTest(AsyncTestCase, _ResolverTestMixin):
def setUp(self):
super(TwistedResolverTest, self).setUp()
self.resolver = TwistedResolver()
class IsValidIPTest(unittest.TestCase):
def test_is_valid_ip(self):
self.assertTrue(is_valid_ip("127.0.0.1"))
self.assertTrue(is_valid_ip("4.4.4.4"))
self.assertTrue(is_valid_ip("::1"))
self.assertTrue(is_valid_ip("2620:0:1cfe:face:b00c::3"))
self.assertTrue(not is_valid_ip("www.google.com"))
self.assertTrue(not is_valid_ip("localhost"))
self.assertTrue(not is_valid_ip("4.4.4.4<"))
self.assertTrue(not is_valid_ip(" 127.0.0.1"))
self.assertTrue(not is_valid_ip(""))
self.assertTrue(not is_valid_ip(" "))
self.assertTrue(not is_valid_ip("\n"))
self.assertTrue(not is_valid_ip("\x00"))
class TestPortAllocation(unittest.TestCase):
def test_same_port_allocation(self):
if "TRAVIS" in os.environ:
self.skipTest("dual-stack servers often have port conflicts on travis")
sockets = bind_sockets(0, "localhost")
try:
port = sockets[0].getsockname()[1]
self.assertTrue(all(s.getsockname()[1] == port for s in sockets[1:]))
finally:
for sock in sockets:
sock.close()
@unittest.skipIf(
not hasattr(socket, "SO_REUSEPORT"), "SO_REUSEPORT is not supported"
)
def test_reuse_port(self):
sockets = [] # type: List[socket.socket]
socket, port = bind_unused_port(reuse_port=True)
try:
sockets = bind_sockets(port, "127.0.0.1", reuse_port=True)
self.assertTrue(all(s.getsockname()[1] == port for s in sockets))
finally:
socket.close()
for sock in sockets:
sock.close()

View File

@@ -0,0 +1,7 @@
port=443
port=443
username='李康'
foo_bar='a'
my_path = __file__

View File

@@ -0,0 +1,329 @@
# -*- coding: utf-8 -*-
import datetime
from io import StringIO
import os
import sys
from unittest import mock
import unittest
from tornado.options import OptionParser, Error
from tornado.util import basestring_type
from tornado.test.util import subTest
import typing
if typing.TYPE_CHECKING:
from typing import List # noqa: F401
class Email(object):
def __init__(self, value):
if isinstance(value, str) and "@" in value:
self._value = value
else:
raise ValueError()
@property
def value(self):
return self._value
class OptionsTest(unittest.TestCase):
def test_parse_command_line(self):
options = OptionParser()
options.define("port", default=80)
options.parse_command_line(["main.py", "--port=443"])
self.assertEqual(options.port, 443)
def test_parse_config_file(self):
options = OptionParser()
options.define("port", default=80)
options.define("username", default="foo")
options.define("my_path")
config_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "options_test.cfg"
)
options.parse_config_file(config_path)
self.assertEqual(options.port, 443)
self.assertEqual(options.username, "李康")
self.assertEqual(options.my_path, config_path)
def test_parse_callbacks(self):
options = OptionParser()
self.called = False
def callback():
self.called = True
options.add_parse_callback(callback)
# non-final parse doesn't run callbacks
options.parse_command_line(["main.py"], final=False)
self.assertFalse(self.called)
# final parse does
options.parse_command_line(["main.py"])
self.assertTrue(self.called)
# callbacks can be run more than once on the same options
# object if there are multiple final parses
self.called = False
options.parse_command_line(["main.py"])
self.assertTrue(self.called)
def test_help(self):
options = OptionParser()
try:
orig_stderr = sys.stderr
sys.stderr = StringIO()
with self.assertRaises(SystemExit):
options.parse_command_line(["main.py", "--help"])
usage = sys.stderr.getvalue()
finally:
sys.stderr = orig_stderr
self.assertIn("Usage:", usage)
def test_subcommand(self):
base_options = OptionParser()
base_options.define("verbose", default=False)
sub_options = OptionParser()
sub_options.define("foo", type=str)
rest = base_options.parse_command_line(
["main.py", "--verbose", "subcommand", "--foo=bar"]
)
self.assertEqual(rest, ["subcommand", "--foo=bar"])
self.assertTrue(base_options.verbose)
rest2 = sub_options.parse_command_line(rest)
self.assertEqual(rest2, [])
self.assertEqual(sub_options.foo, "bar")
# the two option sets are distinct
try:
orig_stderr = sys.stderr
sys.stderr = StringIO()
with self.assertRaises(Error):
sub_options.parse_command_line(["subcommand", "--verbose"])
finally:
sys.stderr = orig_stderr
def test_setattr(self):
options = OptionParser()
options.define("foo", default=1, type=int)
options.foo = 2
self.assertEqual(options.foo, 2)
def test_setattr_type_check(self):
# setattr requires that options be the right type and doesn't
# parse from string formats.
options = OptionParser()
options.define("foo", default=1, type=int)
with self.assertRaises(Error):
options.foo = "2"
def test_setattr_with_callback(self):
values = [] # type: List[int]
options = OptionParser()
options.define("foo", default=1, type=int, callback=values.append)
options.foo = 2
self.assertEqual(values, [2])
def _sample_options(self):
options = OptionParser()
options.define("a", default=1)
options.define("b", default=2)
return options
def test_iter(self):
options = self._sample_options()
# OptionParsers always define 'help'.
self.assertEqual(set(["a", "b", "help"]), set(iter(options)))
def test_getitem(self):
options = self._sample_options()
self.assertEqual(1, options["a"])
def test_setitem(self):
options = OptionParser()
options.define("foo", default=1, type=int)
options["foo"] = 2
self.assertEqual(options["foo"], 2)
def test_items(self):
options = self._sample_options()
# OptionParsers always define 'help'.
expected = [("a", 1), ("b", 2), ("help", options.help)]
actual = sorted(options.items())
self.assertEqual(expected, actual)
def test_as_dict(self):
options = self._sample_options()
expected = {"a": 1, "b": 2, "help": options.help}
self.assertEqual(expected, options.as_dict())
def test_group_dict(self):
options = OptionParser()
options.define("a", default=1)
options.define("b", group="b_group", default=2)
frame = sys._getframe(0)
this_file = frame.f_code.co_filename
self.assertEqual(set(["b_group", "", this_file]), options.groups())
b_group_dict = options.group_dict("b_group")
self.assertEqual({"b": 2}, b_group_dict)
self.assertEqual({}, options.group_dict("nonexistent"))
def test_mock_patch(self):
# ensure that our setattr hooks don't interfere with mock.patch
options = OptionParser()
options.define("foo", default=1)
options.parse_command_line(["main.py", "--foo=2"])
self.assertEqual(options.foo, 2)
with mock.patch.object(options.mockable(), "foo", 3):
self.assertEqual(options.foo, 3)
self.assertEqual(options.foo, 2)
# Try nested patches mixed with explicit sets
with mock.patch.object(options.mockable(), "foo", 4):
self.assertEqual(options.foo, 4)
options.foo = 5
self.assertEqual(options.foo, 5)
with mock.patch.object(options.mockable(), "foo", 6):
self.assertEqual(options.foo, 6)
self.assertEqual(options.foo, 5)
self.assertEqual(options.foo, 2)
def _define_options(self):
options = OptionParser()
options.define("str", type=str)
options.define("basestring", type=basestring_type)
options.define("int", type=int)
options.define("float", type=float)
options.define("datetime", type=datetime.datetime)
options.define("timedelta", type=datetime.timedelta)
options.define("email", type=Email)
options.define("list-of-int", type=int, multiple=True)
return options
def _check_options_values(self, options):
self.assertEqual(options.str, "asdf")
self.assertEqual(options.basestring, "qwer")
self.assertEqual(options.int, 42)
self.assertEqual(options.float, 1.5)
self.assertEqual(options.datetime, datetime.datetime(2013, 4, 28, 5, 16))
self.assertEqual(options.timedelta, datetime.timedelta(seconds=45))
self.assertEqual(options.email.value, "tornado@web.com")
self.assertTrue(isinstance(options.email, Email))
self.assertEqual(options.list_of_int, [1, 2, 3])
def test_types(self):
options = self._define_options()
options.parse_command_line(
[
"main.py",
"--str=asdf",
"--basestring=qwer",
"--int=42",
"--float=1.5",
"--datetime=2013-04-28 05:16",
"--timedelta=45s",
"--email=tornado@web.com",
"--list-of-int=1,2,3",
]
)
self._check_options_values(options)
def test_types_with_conf_file(self):
for config_file_name in (
"options_test_types.cfg",
"options_test_types_str.cfg",
):
options = self._define_options()
options.parse_config_file(
os.path.join(os.path.dirname(__file__), config_file_name)
)
self._check_options_values(options)
def test_multiple_string(self):
options = OptionParser()
options.define("foo", type=str, multiple=True)
options.parse_command_line(["main.py", "--foo=a,b,c"])
self.assertEqual(options.foo, ["a", "b", "c"])
def test_multiple_int(self):
options = OptionParser()
options.define("foo", type=int, multiple=True)
options.parse_command_line(["main.py", "--foo=1,3,5:7"])
self.assertEqual(options.foo, [1, 3, 5, 6, 7])
def test_error_redefine(self):
options = OptionParser()
options.define("foo")
with self.assertRaises(Error) as cm:
options.define("foo")
self.assertRegexpMatches(str(cm.exception), "Option.*foo.*already defined")
def test_error_redefine_underscore(self):
# Ensure that the dash/underscore normalization doesn't
# interfere with the redefinition error.
tests = [
("foo-bar", "foo-bar"),
("foo_bar", "foo_bar"),
("foo-bar", "foo_bar"),
("foo_bar", "foo-bar"),
]
for a, b in tests:
with subTest(self, a=a, b=b):
options = OptionParser()
options.define(a)
with self.assertRaises(Error) as cm:
options.define(b)
self.assertRegexpMatches(
str(cm.exception), "Option.*foo.bar.*already defined"
)
def test_dash_underscore_cli(self):
# Dashes and underscores should be interchangeable.
for defined_name in ["foo-bar", "foo_bar"]:
for flag in ["--foo-bar=a", "--foo_bar=a"]:
options = OptionParser()
options.define(defined_name)
options.parse_command_line(["main.py", flag])
# Attr-style access always uses underscores.
self.assertEqual(options.foo_bar, "a")
# Dict-style access allows both.
self.assertEqual(options["foo-bar"], "a")
self.assertEqual(options["foo_bar"], "a")
def test_dash_underscore_file(self):
# No matter how an option was defined, it can be set with underscores
# in a config file.
for defined_name in ["foo-bar", "foo_bar"]:
options = OptionParser()
options.define(defined_name)
options.parse_config_file(
os.path.join(os.path.dirname(__file__), "options_test.cfg")
)
self.assertEqual(options.foo_bar, "a")
def test_dash_underscore_introspection(self):
# Original names are preserved in introspection APIs.
options = OptionParser()
options.define("with-dash", group="g")
options.define("with_underscore", group="g")
all_options = ["help", "with-dash", "with_underscore"]
self.assertEqual(sorted(options), all_options)
self.assertEqual(sorted(k for (k, v) in options.items()), all_options)
self.assertEqual(sorted(options.as_dict().keys()), all_options)
self.assertEqual(
sorted(options.group_dict("g")), ["with-dash", "with_underscore"]
)
# --help shows CLI-style names with dashes.
buf = StringIO()
options.print_help(buf)
self.assertIn("--with-dash", buf.getvalue())
self.assertIn("--with-underscore", buf.getvalue())

View File

@@ -0,0 +1,11 @@
from datetime import datetime, timedelta
from tornado.test.options_test import Email
str = 'asdf'
basestring = 'qwer'
int = 42
float = 1.5
datetime = datetime(2013, 4, 28, 5, 16)
timedelta = timedelta(0, 45)
email = Email('tornado@web.com')
list_of_int = [1, 2, 3]

View File

@@ -0,0 +1,8 @@
str = 'asdf'
basestring = 'qwer'
int = 42
float = 1.5
datetime = '2013-04-28 05:16'
timedelta = '45s'
email = 'tornado@web.com'
list_of_int = '1,2,3'

View File

@@ -0,0 +1,265 @@
import asyncio
import logging
import os
import signal
import subprocess
import sys
import unittest
from tornado.httpclient import HTTPClient, HTTPError
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
from tornado.log import gen_log
from tornado.process import fork_processes, task_id, Subprocess
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.testing import bind_unused_port, ExpectLog, AsyncTestCase, gen_test
from tornado.test.util import skipIfNonUnix
from tornado.web import RequestHandler, Application
# Not using AsyncHTTPTestCase because we need control over the IOLoop.
@skipIfNonUnix
class ProcessTest(unittest.TestCase):
def get_app(self):
class ProcessHandler(RequestHandler):
def get(self):
if self.get_argument("exit", None):
# must use os._exit instead of sys.exit so unittest's
# exception handler doesn't catch it
os._exit(int(self.get_argument("exit")))
if self.get_argument("signal", None):
os.kill(os.getpid(), int(self.get_argument("signal")))
self.write(str(os.getpid()))
return Application([("/", ProcessHandler)])
def tearDown(self):
if task_id() is not None:
# We're in a child process, and probably got to this point
# via an uncaught exception. If we return now, both
# processes will continue with the rest of the test suite.
# Exit now so the parent process will restart the child
# (since we don't have a clean way to signal failure to
# the parent that won't restart)
logging.error("aborting child process from tearDown")
logging.shutdown()
os._exit(1)
# In the surviving process, clear the alarm we set earlier
signal.alarm(0)
super(ProcessTest, self).tearDown()
def test_multi_process(self):
# This test doesn't work on twisted because we use the global
# reactor and don't restore it to a sane state after the fork
# (asyncio has the same issue, but we have a special case in
# place for it).
with ExpectLog(
gen_log, "(Starting .* processes|child .* exited|uncaught exception)"
):
sock, port = bind_unused_port()
def get_url(path):
return "http://127.0.0.1:%d%s" % (port, path)
# ensure that none of these processes live too long
signal.alarm(5) # master process
try:
id = fork_processes(3, max_restarts=3)
self.assertTrue(id is not None)
signal.alarm(5) # child processes
except SystemExit as e:
# if we exit cleanly from fork_processes, all the child processes
# finished with status 0
self.assertEqual(e.code, 0)
self.assertTrue(task_id() is None)
sock.close()
return
try:
if asyncio is not None:
# Reset the global asyncio event loop, which was put into
# a broken state by the fork.
asyncio.set_event_loop(asyncio.new_event_loop())
if id in (0, 1):
self.assertEqual(id, task_id())
server = HTTPServer(self.get_app())
server.add_sockets([sock])
IOLoop.current().start()
elif id == 2:
self.assertEqual(id, task_id())
sock.close()
# Always use SimpleAsyncHTTPClient here; the curl
# version appears to get confused sometimes if the
# connection gets closed before it's had a chance to
# switch from writing mode to reading mode.
client = HTTPClient(SimpleAsyncHTTPClient)
def fetch(url, fail_ok=False):
try:
return client.fetch(get_url(url))
except HTTPError as e:
if not (fail_ok and e.code == 599):
raise
# Make two processes exit abnormally
fetch("/?exit=2", fail_ok=True)
fetch("/?exit=3", fail_ok=True)
# They've been restarted, so a new fetch will work
int(fetch("/").body)
# Now the same with signals
# Disabled because on the mac a process dying with a signal
# can trigger an "Application exited abnormally; send error
# report to Apple?" prompt.
# fetch("/?signal=%d" % signal.SIGTERM, fail_ok=True)
# fetch("/?signal=%d" % signal.SIGABRT, fail_ok=True)
# int(fetch("/").body)
# Now kill them normally so they won't be restarted
fetch("/?exit=0", fail_ok=True)
# One process left; watch it's pid change
pid = int(fetch("/").body)
fetch("/?exit=4", fail_ok=True)
pid2 = int(fetch("/").body)
self.assertNotEqual(pid, pid2)
# Kill the last one so we shut down cleanly
fetch("/?exit=0", fail_ok=True)
os._exit(0)
except Exception:
logging.error("exception in child process %d", id, exc_info=True)
raise
@skipIfNonUnix
class SubprocessTest(AsyncTestCase):
def term_and_wait(self, subproc):
subproc.proc.terminate()
subproc.proc.wait()
@gen_test
def test_subprocess(self):
if IOLoop.configured_class().__name__.endswith("LayeredTwistedIOLoop"):
# This test fails non-deterministically with LayeredTwistedIOLoop.
# (the read_until('\n') returns '\n' instead of 'hello\n')
# This probably indicates a problem with either TornadoReactor
# or TwistedIOLoop, but I haven't been able to track it down
# and for now this is just causing spurious travis-ci failures.
raise unittest.SkipTest(
"Subprocess tests not compatible with " "LayeredTwistedIOLoop"
)
subproc = Subprocess(
[sys.executable, "-u", "-i"],
stdin=Subprocess.STREAM,
stdout=Subprocess.STREAM,
stderr=subprocess.STDOUT,
)
self.addCleanup(lambda: self.term_and_wait(subproc))
self.addCleanup(subproc.stdout.close)
self.addCleanup(subproc.stdin.close)
yield subproc.stdout.read_until(b">>> ")
subproc.stdin.write(b"print('hello')\n")
data = yield subproc.stdout.read_until(b"\n")
self.assertEqual(data, b"hello\n")
yield subproc.stdout.read_until(b">>> ")
subproc.stdin.write(b"raise SystemExit\n")
data = yield subproc.stdout.read_until_close()
self.assertEqual(data, b"")
@gen_test
def test_close_stdin(self):
# Close the parent's stdin handle and see that the child recognizes it.
subproc = Subprocess(
[sys.executable, "-u", "-i"],
stdin=Subprocess.STREAM,
stdout=Subprocess.STREAM,
stderr=subprocess.STDOUT,
)
self.addCleanup(lambda: self.term_and_wait(subproc))
yield subproc.stdout.read_until(b">>> ")
subproc.stdin.close()
data = yield subproc.stdout.read_until_close()
self.assertEqual(data, b"\n")
@gen_test
def test_stderr(self):
# This test is mysteriously flaky on twisted: it succeeds, but logs
# an error of EBADF on closing a file descriptor.
subproc = Subprocess(
[sys.executable, "-u", "-c", r"import sys; sys.stderr.write('hello\n')"],
stderr=Subprocess.STREAM,
)
self.addCleanup(lambda: self.term_and_wait(subproc))
data = yield subproc.stderr.read_until(b"\n")
self.assertEqual(data, b"hello\n")
# More mysterious EBADF: This fails if done with self.addCleanup instead of here.
subproc.stderr.close()
def test_sigchild(self):
Subprocess.initialize()
self.addCleanup(Subprocess.uninitialize)
subproc = Subprocess([sys.executable, "-c", "pass"])
subproc.set_exit_callback(self.stop)
ret = self.wait()
self.assertEqual(ret, 0)
self.assertEqual(subproc.returncode, ret)
@gen_test
def test_sigchild_future(self):
Subprocess.initialize()
self.addCleanup(Subprocess.uninitialize)
subproc = Subprocess([sys.executable, "-c", "pass"])
ret = yield subproc.wait_for_exit()
self.assertEqual(ret, 0)
self.assertEqual(subproc.returncode, ret)
def test_sigchild_signal(self):
Subprocess.initialize()
self.addCleanup(Subprocess.uninitialize)
subproc = Subprocess(
[sys.executable, "-c", "import time; time.sleep(30)"],
stdout=Subprocess.STREAM,
)
self.addCleanup(subproc.stdout.close)
subproc.set_exit_callback(self.stop)
os.kill(subproc.pid, signal.SIGTERM)
try:
ret = self.wait(timeout=1.0)
except AssertionError:
# We failed to get the termination signal. This test is
# occasionally flaky on pypy, so try to get a little more
# information: did the process close its stdout
# (indicating that the problem is in the parent process's
# signal handling) or did the child process somehow fail
# to terminate?
fut = subproc.stdout.read_until_close()
fut.add_done_callback(lambda f: self.stop()) # type: ignore
try:
self.wait(timeout=1.0)
except AssertionError:
raise AssertionError("subprocess failed to terminate")
else:
raise AssertionError(
"subprocess closed stdout but failed to " "get termination signal"
)
self.assertEqual(subproc.returncode, ret)
self.assertEqual(ret, -signal.SIGTERM)
@gen_test
def test_wait_for_exit_raise(self):
Subprocess.initialize()
self.addCleanup(Subprocess.uninitialize)
subproc = Subprocess([sys.executable, "-c", "import sys; sys.exit(1)"])
with self.assertRaises(subprocess.CalledProcessError) as cm:
yield subproc.wait_for_exit()
self.assertEqual(cm.exception.returncode, 1)
@gen_test
def test_wait_for_exit_raise_disabled(self):
Subprocess.initialize()
self.addCleanup(Subprocess.uninitialize)
subproc = Subprocess([sys.executable, "-c", "import sys; sys.exit(1)"])
ret = yield subproc.wait_for_exit(raise_error=False)
self.assertEqual(ret, 1)

View File

@@ -0,0 +1,427 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import asyncio
from datetime import timedelta
from random import random
import unittest
from tornado import gen, queues
from tornado.gen import TimeoutError
from tornado.testing import gen_test, AsyncTestCase
class QueueBasicTest(AsyncTestCase):
def test_repr_and_str(self):
q = queues.Queue(maxsize=1) # type: queues.Queue[None]
self.assertIn(hex(id(q)), repr(q))
self.assertNotIn(hex(id(q)), str(q))
q.get()
for q_str in repr(q), str(q):
self.assertTrue(q_str.startswith("<Queue"))
self.assertIn("maxsize=1", q_str)
self.assertIn("getters[1]", q_str)
self.assertNotIn("putters", q_str)
self.assertNotIn("tasks", q_str)
q.put(None)
q.put(None)
# Now the queue is full, this putter blocks.
q.put(None)
for q_str in repr(q), str(q):
self.assertNotIn("getters", q_str)
self.assertIn("putters[1]", q_str)
self.assertIn("tasks=2", q_str)
def test_order(self):
q = queues.Queue() # type: queues.Queue[int]
for i in [1, 3, 2]:
q.put_nowait(i)
items = [q.get_nowait() for _ in range(3)]
self.assertEqual([1, 3, 2], items)
@gen_test
def test_maxsize(self):
self.assertRaises(TypeError, queues.Queue, maxsize=None)
self.assertRaises(ValueError, queues.Queue, maxsize=-1)
q = queues.Queue(maxsize=2) # type: queues.Queue[int]
self.assertTrue(q.empty())
self.assertFalse(q.full())
self.assertEqual(2, q.maxsize)
self.assertTrue(q.put(0).done())
self.assertTrue(q.put(1).done())
self.assertFalse(q.empty())
self.assertTrue(q.full())
put2 = q.put(2)
self.assertFalse(put2.done())
self.assertEqual(0, (yield q.get())) # Make room.
self.assertTrue(put2.done())
self.assertFalse(q.empty())
self.assertTrue(q.full())
class QueueGetTest(AsyncTestCase):
@gen_test
def test_blocking_get(self):
q = queues.Queue() # type: queues.Queue[int]
q.put_nowait(0)
self.assertEqual(0, (yield q.get()))
def test_nonblocking_get(self):
q = queues.Queue() # type: queues.Queue[int]
q.put_nowait(0)
self.assertEqual(0, q.get_nowait())
def test_nonblocking_get_exception(self):
q = queues.Queue() # type: queues.Queue[int]
self.assertRaises(queues.QueueEmpty, q.get_nowait)
@gen_test
def test_get_with_putters(self):
q = queues.Queue(1) # type: queues.Queue[int]
q.put_nowait(0)
put = q.put(1)
self.assertEqual(0, (yield q.get()))
self.assertIsNone((yield put))
@gen_test
def test_blocking_get_wait(self):
q = queues.Queue() # type: queues.Queue[int]
q.put(0)
self.io_loop.call_later(0.01, q.put, 1)
self.io_loop.call_later(0.02, q.put, 2)
self.assertEqual(0, (yield q.get(timeout=timedelta(seconds=1))))
self.assertEqual(1, (yield q.get(timeout=timedelta(seconds=1))))
@gen_test
def test_get_timeout(self):
q = queues.Queue() # type: queues.Queue[int]
get_timeout = q.get(timeout=timedelta(seconds=0.01))
get = q.get()
with self.assertRaises(TimeoutError):
yield get_timeout
q.put_nowait(0)
self.assertEqual(0, (yield get))
@gen_test
def test_get_timeout_preempted(self):
q = queues.Queue() # type: queues.Queue[int]
get = q.get(timeout=timedelta(seconds=0.01))
q.put(0)
yield gen.sleep(0.02)
self.assertEqual(0, (yield get))
@gen_test
def test_get_clears_timed_out_putters(self):
q = queues.Queue(1) # type: queues.Queue[int]
# First putter succeeds, remainder block.
putters = [q.put(i, timedelta(seconds=0.01)) for i in range(10)]
put = q.put(10)
self.assertEqual(10, len(q._putters))
yield gen.sleep(0.02)
self.assertEqual(10, len(q._putters))
self.assertFalse(put.done()) # Final waiter is still active.
q.put(11)
self.assertEqual(0, (yield q.get())) # get() clears the waiters.
self.assertEqual(1, len(q._putters))
for putter in putters[1:]:
self.assertRaises(TimeoutError, putter.result)
@gen_test
def test_get_clears_timed_out_getters(self):
q = queues.Queue() # type: queues.Queue[int]
getters = [
asyncio.ensure_future(q.get(timedelta(seconds=0.01))) for _ in range(10)
]
get = asyncio.ensure_future(q.get())
self.assertEqual(11, len(q._getters))
yield gen.sleep(0.02)
self.assertEqual(11, len(q._getters))
self.assertFalse(get.done()) # Final waiter is still active.
q.get() # get() clears the waiters.
self.assertEqual(2, len(q._getters))
for getter in getters:
self.assertRaises(TimeoutError, getter.result)
@gen_test
def test_async_for(self):
q = queues.Queue() # type: queues.Queue[int]
for i in range(5):
q.put(i)
async def f():
results = []
async for i in q:
results.append(i)
if i == 4:
return results
results = yield f()
self.assertEqual(results, list(range(5)))
class QueuePutTest(AsyncTestCase):
@gen_test
def test_blocking_put(self):
q = queues.Queue() # type: queues.Queue[int]
q.put(0)
self.assertEqual(0, q.get_nowait())
def test_nonblocking_put_exception(self):
q = queues.Queue(1) # type: queues.Queue[int]
q.put(0)
self.assertRaises(queues.QueueFull, q.put_nowait, 1)
@gen_test
def test_put_with_getters(self):
q = queues.Queue() # type: queues.Queue[int]
get0 = q.get()
get1 = q.get()
yield q.put(0)
self.assertEqual(0, (yield get0))
yield q.put(1)
self.assertEqual(1, (yield get1))
@gen_test
def test_nonblocking_put_with_getters(self):
q = queues.Queue() # type: queues.Queue[int]
get0 = q.get()
get1 = q.get()
q.put_nowait(0)
# put_nowait does *not* immediately unblock getters.
yield gen.moment
self.assertEqual(0, (yield get0))
q.put_nowait(1)
yield gen.moment
self.assertEqual(1, (yield get1))
@gen_test
def test_blocking_put_wait(self):
q = queues.Queue(1) # type: queues.Queue[int]
q.put_nowait(0)
self.io_loop.call_later(0.01, q.get)
self.io_loop.call_later(0.02, q.get)
futures = [q.put(0), q.put(1)]
self.assertFalse(any(f.done() for f in futures))
yield futures
@gen_test
def test_put_timeout(self):
q = queues.Queue(1) # type: queues.Queue[int]
q.put_nowait(0) # Now it's full.
put_timeout = q.put(1, timeout=timedelta(seconds=0.01))
put = q.put(2)
with self.assertRaises(TimeoutError):
yield put_timeout
self.assertEqual(0, q.get_nowait())
# 1 was never put in the queue.
self.assertEqual(2, (yield q.get()))
# Final get() unblocked this putter.
yield put
@gen_test
def test_put_timeout_preempted(self):
q = queues.Queue(1) # type: queues.Queue[int]
q.put_nowait(0)
put = q.put(1, timeout=timedelta(seconds=0.01))
q.get()
yield gen.sleep(0.02)
yield put # No TimeoutError.
@gen_test
def test_put_clears_timed_out_putters(self):
q = queues.Queue(1) # type: queues.Queue[int]
# First putter succeeds, remainder block.
putters = [q.put(i, timedelta(seconds=0.01)) for i in range(10)]
put = q.put(10)
self.assertEqual(10, len(q._putters))
yield gen.sleep(0.02)
self.assertEqual(10, len(q._putters))
self.assertFalse(put.done()) # Final waiter is still active.
q.put(11) # put() clears the waiters.
self.assertEqual(2, len(q._putters))
for putter in putters[1:]:
self.assertRaises(TimeoutError, putter.result)
@gen_test
def test_put_clears_timed_out_getters(self):
q = queues.Queue() # type: queues.Queue[int]
getters = [
asyncio.ensure_future(q.get(timedelta(seconds=0.01))) for _ in range(10)
]
get = asyncio.ensure_future(q.get())
q.get()
self.assertEqual(12, len(q._getters))
yield gen.sleep(0.02)
self.assertEqual(12, len(q._getters))
self.assertFalse(get.done()) # Final waiters still active.
q.put(0) # put() clears the waiters.
self.assertEqual(1, len(q._getters))
self.assertEqual(0, (yield get))
for getter in getters:
self.assertRaises(TimeoutError, getter.result)
@gen_test
def test_float_maxsize(self):
# If a float is passed for maxsize, a reasonable limit should
# be enforced, instead of being treated as unlimited.
# It happens to be rounded up.
# http://bugs.python.org/issue21723
q = queues.Queue(maxsize=1.3) # type: ignore
self.assertTrue(q.empty())
self.assertFalse(q.full())
q.put_nowait(0)
q.put_nowait(1)
self.assertFalse(q.empty())
self.assertTrue(q.full())
self.assertRaises(queues.QueueFull, q.put_nowait, 2)
self.assertEqual(0, q.get_nowait())
self.assertFalse(q.empty())
self.assertFalse(q.full())
yield q.put(2)
put = q.put(3)
self.assertFalse(put.done())
self.assertEqual(1, (yield q.get()))
yield put
self.assertTrue(q.full())
class QueueJoinTest(AsyncTestCase):
queue_class = queues.Queue
def test_task_done_underflow(self):
q = self.queue_class()
self.assertRaises(ValueError, q.task_done)
@gen_test
def test_task_done(self):
q = self.queue_class()
for i in range(100):
q.put_nowait(i)
self.accumulator = 0
@gen.coroutine
def worker():
while True:
item = yield q.get()
self.accumulator += item
q.task_done()
yield gen.sleep(random() * 0.01)
# Two coroutines share work.
worker()
worker()
yield q.join()
self.assertEqual(sum(range(100)), self.accumulator)
@gen_test
def test_task_done_delay(self):
# Verify it is task_done(), not get(), that unblocks join().
q = self.queue_class()
q.put_nowait(0)
join = q.join()
self.assertFalse(join.done())
yield q.get()
self.assertFalse(join.done())
yield gen.moment
self.assertFalse(join.done())
q.task_done()
self.assertTrue(join.done())
@gen_test
def test_join_empty_queue(self):
q = self.queue_class()
yield q.join()
yield q.join()
@gen_test
def test_join_timeout(self):
q = self.queue_class()
q.put(0)
with self.assertRaises(TimeoutError):
yield q.join(timeout=timedelta(seconds=0.01))
class PriorityQueueJoinTest(QueueJoinTest):
queue_class = queues.PriorityQueue
@gen_test
def test_order(self):
q = self.queue_class(maxsize=2)
q.put_nowait((1, "a"))
q.put_nowait((0, "b"))
self.assertTrue(q.full())
q.put((3, "c"))
q.put((2, "d"))
self.assertEqual((0, "b"), q.get_nowait())
self.assertEqual((1, "a"), (yield q.get()))
self.assertEqual((2, "d"), q.get_nowait())
self.assertEqual((3, "c"), (yield q.get()))
self.assertTrue(q.empty())
class LifoQueueJoinTest(QueueJoinTest):
queue_class = queues.LifoQueue
@gen_test
def test_order(self):
q = self.queue_class(maxsize=2)
q.put_nowait(1)
q.put_nowait(0)
self.assertTrue(q.full())
q.put(3)
q.put(2)
self.assertEqual(3, q.get_nowait())
self.assertEqual(2, (yield q.get()))
self.assertEqual(0, q.get_nowait())
self.assertEqual(1, (yield q.get()))
self.assertTrue(q.empty())
class ProducerConsumerTest(AsyncTestCase):
@gen_test
def test_producer_consumer(self):
q = queues.Queue(maxsize=3) # type: queues.Queue[int]
history = []
# We don't yield between get() and task_done(), so get() must wait for
# the next tick. Otherwise we'd immediately call task_done and unblock
# join() before q.put() resumes, and we'd only process the first four
# items.
@gen.coroutine
def consumer():
while True:
history.append((yield q.get()))
q.task_done()
@gen.coroutine
def producer():
for item in range(10):
yield q.put(item)
consumer()
yield producer()
yield q.join()
self.assertEqual(list(range(10)), history)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,10 @@
from tornado.ioloop import IOLoop
from tornado.netutil import ThreadedResolver
# When this module is imported, it runs getaddrinfo on a thread. Since
# the hostname is unicode, getaddrinfo attempts to import encodings.idna
# but blocks on the import lock. Verify that ThreadedResolver avoids
# this deadlock.
resolver = ThreadedResolver()
IOLoop.current().run_sync(lambda: resolver.resolve(u"localhost", 80))

View File

@@ -0,0 +1,276 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from tornado.httputil import (
HTTPHeaders,
HTTPMessageDelegate,
HTTPServerConnectionDelegate,
ResponseStartLine,
)
from tornado.routing import (
HostMatches,
PathMatches,
ReversibleRouter,
Router,
Rule,
RuleRouter,
)
from tornado.testing import AsyncHTTPTestCase
from tornado.web import Application, HTTPError, RequestHandler
from tornado.wsgi import WSGIContainer
import typing # noqa: F401
class BasicRouter(Router):
def find_handler(self, request, **kwargs):
class MessageDelegate(HTTPMessageDelegate):
def __init__(self, connection):
self.connection = connection
def finish(self):
self.connection.write_headers(
ResponseStartLine("HTTP/1.1", 200, "OK"),
HTTPHeaders({"Content-Length": "2"}),
b"OK",
)
self.connection.finish()
return MessageDelegate(request.connection)
class BasicRouterTestCase(AsyncHTTPTestCase):
def get_app(self):
return BasicRouter()
def test_basic_router(self):
response = self.fetch("/any_request")
self.assertEqual(response.body, b"OK")
resources = {} # type: typing.Dict[str, bytes]
class GetResource(RequestHandler):
def get(self, path):
if path not in resources:
raise HTTPError(404)
self.finish(resources[path])
class PostResource(RequestHandler):
def post(self, path):
resources[path] = self.request.body
class HTTPMethodRouter(Router):
def __init__(self, app):
self.app = app
def find_handler(self, request, **kwargs):
handler = GetResource if request.method == "GET" else PostResource
return self.app.get_handler_delegate(request, handler, path_args=[request.path])
class HTTPMethodRouterTestCase(AsyncHTTPTestCase):
def get_app(self):
return HTTPMethodRouter(Application())
def test_http_method_router(self):
response = self.fetch("/post_resource", method="POST", body="data")
self.assertEqual(response.code, 200)
response = self.fetch("/get_resource")
self.assertEqual(response.code, 404)
response = self.fetch("/post_resource")
self.assertEqual(response.code, 200)
self.assertEqual(response.body, b"data")
def _get_named_handler(handler_name):
class Handler(RequestHandler):
def get(self, *args, **kwargs):
if self.application.settings.get("app_name") is not None:
self.write(self.application.settings["app_name"] + ": ")
self.finish(handler_name + ": " + self.reverse_url(handler_name))
return Handler
FirstHandler = _get_named_handler("first_handler")
SecondHandler = _get_named_handler("second_handler")
class CustomRouter(ReversibleRouter):
def __init__(self):
super(CustomRouter, self).__init__()
self.routes = {} # type: typing.Dict[str, typing.Any]
def add_routes(self, routes):
self.routes.update(routes)
def find_handler(self, request, **kwargs):
if request.path in self.routes:
app, handler = self.routes[request.path]
return app.get_handler_delegate(request, handler)
def reverse_url(self, name, *args):
handler_path = "/" + name
return handler_path if handler_path in self.routes else None
class CustomRouterTestCase(AsyncHTTPTestCase):
def get_app(self):
router = CustomRouter()
class CustomApplication(Application):
def reverse_url(self, name, *args):
return router.reverse_url(name, *args)
app1 = CustomApplication(app_name="app1")
app2 = CustomApplication(app_name="app2")
router.add_routes(
{
"/first_handler": (app1, FirstHandler),
"/second_handler": (app2, SecondHandler),
"/first_handler_second_app": (app2, FirstHandler),
}
)
return router
def test_custom_router(self):
response = self.fetch("/first_handler")
self.assertEqual(response.body, b"app1: first_handler: /first_handler")
response = self.fetch("/second_handler")
self.assertEqual(response.body, b"app2: second_handler: /second_handler")
response = self.fetch("/first_handler_second_app")
self.assertEqual(response.body, b"app2: first_handler: /first_handler")
class ConnectionDelegate(HTTPServerConnectionDelegate):
def start_request(self, server_conn, request_conn):
class MessageDelegate(HTTPMessageDelegate):
def __init__(self, connection):
self.connection = connection
def finish(self):
response_body = b"OK"
self.connection.write_headers(
ResponseStartLine("HTTP/1.1", 200, "OK"),
HTTPHeaders({"Content-Length": str(len(response_body))}),
)
self.connection.write(response_body)
self.connection.finish()
return MessageDelegate(request_conn)
class RuleRouterTest(AsyncHTTPTestCase):
def get_app(self):
app = Application()
def request_callable(request):
request.connection.write_headers(
ResponseStartLine("HTTP/1.1", 200, "OK"),
HTTPHeaders({"Content-Length": "2"}),
)
request.connection.write(b"OK")
request.connection.finish()
router = CustomRouter()
router.add_routes(
{"/nested_handler": (app, _get_named_handler("nested_handler"))}
)
app.add_handlers(
".*",
[
(
HostMatches("www.example.com"),
[
(
PathMatches("/first_handler"),
"tornado.test.routing_test.SecondHandler",
{},
"second_handler",
)
],
),
Rule(PathMatches("/.*handler"), router),
Rule(PathMatches("/first_handler"), FirstHandler, name="first_handler"),
Rule(PathMatches("/request_callable"), request_callable),
("/connection_delegate", ConnectionDelegate()),
],
)
return app
def test_rule_based_router(self):
response = self.fetch("/first_handler")
self.assertEqual(response.body, b"first_handler: /first_handler")
response = self.fetch("/first_handler", headers={"Host": "www.example.com"})
self.assertEqual(response.body, b"second_handler: /first_handler")
response = self.fetch("/nested_handler")
self.assertEqual(response.body, b"nested_handler: /nested_handler")
response = self.fetch("/nested_not_found_handler")
self.assertEqual(response.code, 404)
response = self.fetch("/connection_delegate")
self.assertEqual(response.body, b"OK")
response = self.fetch("/request_callable")
self.assertEqual(response.body, b"OK")
response = self.fetch("/404")
self.assertEqual(response.code, 404)
class WSGIContainerTestCase(AsyncHTTPTestCase):
def get_app(self):
wsgi_app = WSGIContainer(self.wsgi_app)
class Handler(RequestHandler):
def get(self, *args, **kwargs):
self.finish(self.reverse_url("tornado"))
return RuleRouter(
[
(
PathMatches("/tornado.*"),
Application([(r"/tornado/test", Handler, {}, "tornado")]),
),
(PathMatches("/wsgi"), wsgi_app),
]
)
def wsgi_app(self, environ, start_response):
start_response("200 OK", [])
return [b"WSGI"]
def test_wsgi_container(self):
response = self.fetch("/tornado/test")
self.assertEqual(response.body, b"/tornado/test")
response = self.fetch("/wsgi")
self.assertEqual(response.body, b"WSGI")
def test_delegate_not_found(self):
response = self.fetch("/404")
self.assertEqual(response.code, 404)

View File

@@ -0,0 +1,236 @@
from functools import reduce
import gc
import io
import locale # system locale module, not tornado.locale
import logging
import operator
import textwrap
import sys
import unittest
import warnings
from tornado.httpclient import AsyncHTTPClient
from tornado.httpserver import HTTPServer
from tornado.netutil import Resolver
from tornado.options import define, add_parse_callback
TEST_MODULES = [
"tornado.httputil.doctests",
"tornado.iostream.doctests",
"tornado.util.doctests",
"tornado.test.asyncio_test",
"tornado.test.auth_test",
"tornado.test.autoreload_test",
"tornado.test.concurrent_test",
"tornado.test.curl_httpclient_test",
"tornado.test.escape_test",
"tornado.test.gen_test",
"tornado.test.http1connection_test",
"tornado.test.httpclient_test",
"tornado.test.httpserver_test",
"tornado.test.httputil_test",
"tornado.test.import_test",
"tornado.test.ioloop_test",
"tornado.test.iostream_test",
"tornado.test.locale_test",
"tornado.test.locks_test",
"tornado.test.netutil_test",
"tornado.test.log_test",
"tornado.test.options_test",
"tornado.test.process_test",
"tornado.test.queues_test",
"tornado.test.routing_test",
"tornado.test.simple_httpclient_test",
"tornado.test.tcpclient_test",
"tornado.test.tcpserver_test",
"tornado.test.template_test",
"tornado.test.testing_test",
"tornado.test.twisted_test",
"tornado.test.util_test",
"tornado.test.web_test",
"tornado.test.websocket_test",
"tornado.test.windows_test",
"tornado.test.wsgi_test",
]
def all():
return unittest.defaultTestLoader.loadTestsFromNames(TEST_MODULES)
def test_runner_factory(stderr):
class TornadoTextTestRunner(unittest.TextTestRunner):
def __init__(self, *args, **kwargs):
kwargs["stream"] = stderr
super(TornadoTextTestRunner, self).__init__(*args, **kwargs)
def run(self, test):
result = super(TornadoTextTestRunner, self).run(test)
if result.skipped:
skip_reasons = set(reason for (test, reason) in result.skipped)
self.stream.write(
textwrap.fill(
"Some tests were skipped because: %s"
% ", ".join(sorted(skip_reasons))
)
)
self.stream.write("\n")
return result
return TornadoTextTestRunner
class LogCounter(logging.Filter):
"""Counts the number of WARNING or higher log records."""
def __init__(self, *args, **kwargs):
super(LogCounter, self).__init__(*args, **kwargs)
self.info_count = self.warning_count = self.error_count = 0
def filter(self, record):
if record.levelno >= logging.ERROR:
self.error_count += 1
elif record.levelno >= logging.WARNING:
self.warning_count += 1
elif record.levelno >= logging.INFO:
self.info_count += 1
return True
class CountingStderr(io.IOBase):
def __init__(self, real):
self.real = real
self.byte_count = 0
def write(self, data):
self.byte_count += len(data)
return self.real.write(data)
def flush(self):
return self.real.flush()
def main():
# Be strict about most warnings (This is set in our test running
# scripts to catch import-time warnings, but set it again here to
# be sure). This also turns on warnings that are ignored by
# default, including DeprecationWarnings and python 3.2's
# ResourceWarnings.
warnings.filterwarnings("error")
# setuptools sometimes gives ImportWarnings about things that are on
# sys.path even if they're not being used.
warnings.filterwarnings("ignore", category=ImportWarning)
# Tornado generally shouldn't use anything deprecated, but some of
# our dependencies do (last match wins).
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("error", category=DeprecationWarning, module=r"tornado\..*")
warnings.filterwarnings("ignore", category=PendingDeprecationWarning)
warnings.filterwarnings(
"error", category=PendingDeprecationWarning, module=r"tornado\..*"
)
# The unittest module is aggressive about deprecating redundant methods,
# leaving some without non-deprecated spellings that work on both
# 2.7 and 3.2
warnings.filterwarnings(
"ignore", category=DeprecationWarning, message="Please use assert.* instead"
)
warnings.filterwarnings(
"ignore",
category=PendingDeprecationWarning,
message="Please use assert.* instead",
)
# Twisted 15.0.0 triggers some warnings on py3 with -bb.
warnings.filterwarnings("ignore", category=BytesWarning, module=r"twisted\..*")
if (3,) < sys.version_info < (3, 6):
# Prior to 3.6, async ResourceWarnings were rather noisy
# and even
# `python3.4 -W error -c 'import asyncio; asyncio.get_event_loop()'`
# would generate a warning.
warnings.filterwarnings(
"ignore", category=ResourceWarning, module=r"asyncio\..*"
)
# This deprecation warning is introduced in Python 3.8 and is
# triggered by pycurl. Unforunately, because it is raised in the C
# layer it can't be filtered by module and we must match the
# message text instead (Tornado's C module uses PY_SSIZE_T_CLEAN
# so it's not at risk of running into this issue).
warnings.filterwarnings(
"ignore",
category=DeprecationWarning,
message="PY_SSIZE_T_CLEAN will be required",
)
logging.getLogger("tornado.access").setLevel(logging.CRITICAL)
define(
"httpclient",
type=str,
default=None,
callback=lambda s: AsyncHTTPClient.configure(
s, defaults=dict(allow_ipv6=False)
),
)
define("httpserver", type=str, default=None, callback=HTTPServer.configure)
define("resolver", type=str, default=None, callback=Resolver.configure)
define(
"debug_gc",
type=str,
multiple=True,
help="A comma-separated list of gc module debug constants, "
"e.g. DEBUG_STATS or DEBUG_COLLECTABLE,DEBUG_OBJECTS",
callback=lambda values: gc.set_debug(
reduce(operator.or_, (getattr(gc, v) for v in values))
),
)
def set_locale(x):
locale.setlocale(locale.LC_ALL, x)
define("locale", type=str, default=None, callback=set_locale)
log_counter = LogCounter()
add_parse_callback(lambda: logging.getLogger().handlers[0].addFilter(log_counter))
# Certain errors (especially "unclosed resource" errors raised in
# destructors) go directly to stderr instead of logging. Count
# anything written by anything but the test runner as an error.
orig_stderr = sys.stderr
counting_stderr = CountingStderr(orig_stderr)
sys.stderr = counting_stderr # type: ignore
import tornado.testing
kwargs = {}
# HACK: unittest.main will make its own changes to the warning
# configuration, which may conflict with the settings above
# or command-line flags like -bb. Passing warnings=False
# suppresses this behavior, although this looks like an implementation
# detail. http://bugs.python.org/issue15626
kwargs["warnings"] = False
kwargs["testRunner"] = test_runner_factory(orig_stderr)
try:
tornado.testing.main(**kwargs)
finally:
# The tests should run clean; consider it a failure if they
# logged anything at info level or above.
if (
log_counter.info_count > 0
or log_counter.warning_count > 0
or log_counter.error_count > 0
or counting_stderr.byte_count > 0
):
logging.error(
"logged %d infos, %d warnings, %d errors, and %d bytes to stderr",
log_counter.info_count,
log_counter.warning_count,
log_counter.error_count,
counting_stderr.byte_count,
)
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,832 @@
import collections
from contextlib import closing
import errno
import gzip
import logging
import os
import re
import socket
import ssl
import sys
import typing # noqa: F401
from tornado.escape import to_unicode, utf8
from tornado import gen
from tornado.httpclient import AsyncHTTPClient
from tornado.httputil import HTTPHeaders, ResponseStartLine
from tornado.ioloop import IOLoop
from tornado.iostream import UnsatisfiableReadError
from tornado.locks import Event
from tornado.log import gen_log
from tornado.netutil import Resolver, bind_sockets
from tornado.simple_httpclient import (
SimpleAsyncHTTPClient,
HTTPStreamClosedError,
HTTPTimeoutError,
)
from tornado.test.httpclient_test import (
ChunkHandler,
CountdownHandler,
HelloWorldHandler,
RedirectHandler,
)
from tornado.test import httpclient_test
from tornado.testing import (
AsyncHTTPTestCase,
AsyncHTTPSTestCase,
AsyncTestCase,
ExpectLog,
gen_test,
)
from tornado.test.util import skipOnTravis, skipIfNoIPv6, refusing_port
from tornado.web import RequestHandler, Application, url, stream_request_body
class SimpleHTTPClientCommonTestCase(httpclient_test.HTTPClientCommonTestCase):
def get_http_client(self):
client = SimpleAsyncHTTPClient(force_instance=True)
self.assertTrue(isinstance(client, SimpleAsyncHTTPClient))
return client
class TriggerHandler(RequestHandler):
def initialize(self, queue, wake_callback):
self.queue = queue
self.wake_callback = wake_callback
@gen.coroutine
def get(self):
logging.debug("queuing trigger")
event = Event()
self.queue.append(event.set)
if self.get_argument("wake", "true") == "true":
self.wake_callback()
yield event.wait()
class ContentLengthHandler(RequestHandler):
def get(self):
self.stream = self.detach()
IOLoop.current().spawn_callback(self.write_response)
@gen.coroutine
def write_response(self):
yield self.stream.write(
utf8(
"HTTP/1.0 200 OK\r\nContent-Length: %s\r\n\r\nok"
% self.get_argument("value")
)
)
self.stream.close()
class HeadHandler(RequestHandler):
def head(self):
self.set_header("Content-Length", "7")
class OptionsHandler(RequestHandler):
def options(self):
self.set_header("Access-Control-Allow-Origin", "*")
self.write("ok")
class NoContentHandler(RequestHandler):
def get(self):
self.set_status(204)
self.finish()
class SeeOtherPostHandler(RequestHandler):
def post(self):
redirect_code = int(self.request.body)
assert redirect_code in (302, 303), "unexpected body %r" % self.request.body
self.set_header("Location", "/see_other_get")
self.set_status(redirect_code)
class SeeOtherGetHandler(RequestHandler):
def get(self):
if self.request.body:
raise Exception("unexpected body %r" % self.request.body)
self.write("ok")
class HostEchoHandler(RequestHandler):
def get(self):
self.write(self.request.headers["Host"])
class NoContentLengthHandler(RequestHandler):
def get(self):
if self.request.version.startswith("HTTP/1"):
# Emulate the old HTTP/1.0 behavior of returning a body with no
# content-length. Tornado handles content-length at the framework
# level so we have to go around it.
stream = self.detach()
stream.write(b"HTTP/1.0 200 OK\r\n\r\n" b"hello")
stream.close()
else:
self.finish("HTTP/1 required")
class EchoPostHandler(RequestHandler):
def post(self):
self.write(self.request.body)
@stream_request_body
class RespondInPrepareHandler(RequestHandler):
def prepare(self):
self.set_status(403)
self.finish("forbidden")
class SimpleHTTPClientTestMixin(object):
def get_app(self):
# callable objects to finish pending /trigger requests
self.triggers = collections.deque() # type: typing.Deque[str]
return Application(
[
url(
"/trigger",
TriggerHandler,
dict(queue=self.triggers, wake_callback=self.stop),
),
url("/chunk", ChunkHandler),
url("/countdown/([0-9]+)", CountdownHandler, name="countdown"),
url("/hello", HelloWorldHandler),
url("/content_length", ContentLengthHandler),
url("/head", HeadHandler),
url("/options", OptionsHandler),
url("/no_content", NoContentHandler),
url("/see_other_post", SeeOtherPostHandler),
url("/see_other_get", SeeOtherGetHandler),
url("/host_echo", HostEchoHandler),
url("/no_content_length", NoContentLengthHandler),
url("/echo_post", EchoPostHandler),
url("/respond_in_prepare", RespondInPrepareHandler),
url("/redirect", RedirectHandler),
],
gzip=True,
)
def test_singleton(self):
# Class "constructor" reuses objects on the same IOLoop
self.assertTrue(SimpleAsyncHTTPClient() is SimpleAsyncHTTPClient())
# unless force_instance is used
self.assertTrue(
SimpleAsyncHTTPClient() is not SimpleAsyncHTTPClient(force_instance=True)
)
# different IOLoops use different objects
with closing(IOLoop()) as io_loop2:
async def make_client():
await gen.sleep(0)
return SimpleAsyncHTTPClient()
client1 = self.io_loop.run_sync(make_client)
client2 = io_loop2.run_sync(make_client)
self.assertTrue(client1 is not client2)
def test_connection_limit(self):
with closing(self.create_client(max_clients=2)) as client:
self.assertEqual(client.max_clients, 2)
seen = []
# Send 4 requests. Two can be sent immediately, while the others
# will be queued
for i in range(4):
def cb(fut, i=i):
seen.append(i)
self.stop()
client.fetch(self.get_url("/trigger")).add_done_callback(cb)
self.wait(condition=lambda: len(self.triggers) == 2)
self.assertEqual(len(client.queue), 2)
# Finish the first two requests and let the next two through
self.triggers.popleft()()
self.triggers.popleft()()
self.wait(condition=lambda: (len(self.triggers) == 2 and len(seen) == 2))
self.assertEqual(set(seen), set([0, 1]))
self.assertEqual(len(client.queue), 0)
# Finish all the pending requests
self.triggers.popleft()()
self.triggers.popleft()()
self.wait(condition=lambda: len(seen) == 4)
self.assertEqual(set(seen), set([0, 1, 2, 3]))
self.assertEqual(len(self.triggers), 0)
@gen_test
def test_redirect_connection_limit(self):
# following redirects should not consume additional connections
with closing(self.create_client(max_clients=1)) as client:
response = yield client.fetch(self.get_url("/countdown/3"), max_redirects=3)
response.rethrow()
def test_gzip(self):
# All the tests in this file should be using gzip, but this test
# ensures that it is in fact getting compressed.
# Setting Accept-Encoding manually bypasses the client's
# decompression so we can see the raw data.
response = self.fetch(
"/chunk", use_gzip=False, headers={"Accept-Encoding": "gzip"}
)
self.assertEqual(response.headers["Content-Encoding"], "gzip")
self.assertNotEqual(response.body, b"asdfqwer")
# Our test data gets bigger when gzipped. Oops. :)
# Chunked encoding bypasses the MIN_LENGTH check.
self.assertEqual(len(response.body), 34)
f = gzip.GzipFile(mode="r", fileobj=response.buffer)
self.assertEqual(f.read(), b"asdfqwer")
def test_max_redirects(self):
response = self.fetch("/countdown/5", max_redirects=3)
self.assertEqual(302, response.code)
# We requested 5, followed three redirects for 4, 3, 2, then the last
# unfollowed redirect is to 1.
self.assertTrue(response.request.url.endswith("/countdown/5"))
self.assertTrue(response.effective_url.endswith("/countdown/2"))
self.assertTrue(response.headers["Location"].endswith("/countdown/1"))
def test_header_reuse(self):
# Apps may reuse a headers object if they are only passing in constant
# headers like user-agent. The header object should not be modified.
headers = HTTPHeaders({"User-Agent": "Foo"})
self.fetch("/hello", headers=headers)
self.assertEqual(list(headers.get_all()), [("User-Agent", "Foo")])
def test_see_other_redirect(self):
for code in (302, 303):
response = self.fetch("/see_other_post", method="POST", body="%d" % code)
self.assertEqual(200, response.code)
self.assertTrue(response.request.url.endswith("/see_other_post"))
self.assertTrue(response.effective_url.endswith("/see_other_get"))
# request is the original request, is a POST still
self.assertEqual("POST", response.request.method)
@skipOnTravis
@gen_test
def test_connect_timeout(self):
timeout = 0.1
cleanup_event = Event()
test = self
class TimeoutResolver(Resolver):
async def resolve(self, *args, **kwargs):
await cleanup_event.wait()
# Return something valid so the test doesn't raise during shutdown.
return [(socket.AF_INET, ("127.0.0.1", test.get_http_port()))]
with closing(self.create_client(resolver=TimeoutResolver())) as client:
with self.assertRaises(HTTPTimeoutError):
yield client.fetch(
self.get_url("/hello"),
connect_timeout=timeout,
request_timeout=3600,
raise_error=True,
)
# Let the hanging coroutine clean up after itself. We need to
# wait more than a single IOLoop iteration for the SSL case,
# which logs errors on unexpected EOF.
cleanup_event.set()
yield gen.sleep(0.2)
@skipOnTravis
def test_request_timeout(self):
timeout = 0.1
if os.name == "nt":
timeout = 0.5
with self.assertRaises(HTTPTimeoutError):
self.fetch("/trigger?wake=false", request_timeout=timeout, raise_error=True)
# trigger the hanging request to let it clean up after itself
self.triggers.popleft()()
self.io_loop.run_sync(lambda: gen.sleep(0))
@skipIfNoIPv6
def test_ipv6(self):
[sock] = bind_sockets(0, "::1", family=socket.AF_INET6)
port = sock.getsockname()[1]
self.http_server.add_socket(sock)
url = "%s://[::1]:%d/hello" % (self.get_protocol(), port)
# ipv6 is currently enabled by default but can be disabled
with self.assertRaises(Exception):
self.fetch(url, allow_ipv6=False, raise_error=True)
response = self.fetch(url)
self.assertEqual(response.body, b"Hello world!")
def test_multiple_content_length_accepted(self):
response = self.fetch("/content_length?value=2,2")
self.assertEqual(response.body, b"ok")
response = self.fetch("/content_length?value=2,%202,2")
self.assertEqual(response.body, b"ok")
with ExpectLog(gen_log, ".*Multiple unequal Content-Lengths"):
with self.assertRaises(HTTPStreamClosedError):
self.fetch("/content_length?value=2,4", raise_error=True)
with self.assertRaises(HTTPStreamClosedError):
self.fetch("/content_length?value=2,%202,3", raise_error=True)
def test_head_request(self):
response = self.fetch("/head", method="HEAD")
self.assertEqual(response.code, 200)
self.assertEqual(response.headers["content-length"], "7")
self.assertFalse(response.body)
def test_options_request(self):
response = self.fetch("/options", method="OPTIONS")
self.assertEqual(response.code, 200)
self.assertEqual(response.headers["content-length"], "2")
self.assertEqual(response.headers["access-control-allow-origin"], "*")
self.assertEqual(response.body, b"ok")
def test_no_content(self):
response = self.fetch("/no_content")
self.assertEqual(response.code, 204)
# 204 status shouldn't have a content-length
#
# Tests with a content-length header are included below
# in HTTP204NoContentTestCase.
self.assertNotIn("Content-Length", response.headers)
def test_host_header(self):
host_re = re.compile(b"^127.0.0.1:[0-9]+$")
response = self.fetch("/host_echo")
self.assertTrue(host_re.match(response.body))
url = self.get_url("/host_echo").replace("http://", "http://me:secret@")
response = self.fetch(url)
self.assertTrue(host_re.match(response.body), response.body)
def test_connection_refused(self):
cleanup_func, port = refusing_port()
self.addCleanup(cleanup_func)
with ExpectLog(gen_log, ".*", required=False):
with self.assertRaises(socket.error) as cm:
self.fetch("http://127.0.0.1:%d/" % port, raise_error=True)
if sys.platform != "cygwin":
# cygwin returns EPERM instead of ECONNREFUSED here
contains_errno = str(errno.ECONNREFUSED) in str(cm.exception)
if not contains_errno and hasattr(errno, "WSAECONNREFUSED"):
contains_errno = str(errno.WSAECONNREFUSED) in str( # type: ignore
cm.exception
)
self.assertTrue(contains_errno, cm.exception)
# This is usually "Connection refused".
# On windows, strerror is broken and returns "Unknown error".
expected_message = os.strerror(errno.ECONNREFUSED)
self.assertTrue(expected_message in str(cm.exception), cm.exception)
def test_queue_timeout(self):
with closing(self.create_client(max_clients=1)) as client:
# Wait for the trigger request to block, not complete.
fut1 = client.fetch(self.get_url("/trigger"), request_timeout=10)
self.wait()
with self.assertRaises(HTTPTimeoutError) as cm:
self.io_loop.run_sync(
lambda: client.fetch(
self.get_url("/hello"), connect_timeout=0.1, raise_error=True
)
)
self.assertEqual(str(cm.exception), "Timeout in request queue")
self.triggers.popleft()()
self.io_loop.run_sync(lambda: fut1)
def test_no_content_length(self):
response = self.fetch("/no_content_length")
if response.body == b"HTTP/1 required":
self.skipTest("requires HTTP/1.x")
else:
self.assertEquals(b"hello", response.body)
def sync_body_producer(self, write):
write(b"1234")
write(b"5678")
@gen.coroutine
def async_body_producer(self, write):
yield write(b"1234")
yield gen.moment
yield write(b"5678")
def test_sync_body_producer_chunked(self):
response = self.fetch(
"/echo_post", method="POST", body_producer=self.sync_body_producer
)
response.rethrow()
self.assertEqual(response.body, b"12345678")
def test_sync_body_producer_content_length(self):
response = self.fetch(
"/echo_post",
method="POST",
body_producer=self.sync_body_producer,
headers={"Content-Length": "8"},
)
response.rethrow()
self.assertEqual(response.body, b"12345678")
def test_async_body_producer_chunked(self):
response = self.fetch(
"/echo_post", method="POST", body_producer=self.async_body_producer
)
response.rethrow()
self.assertEqual(response.body, b"12345678")
def test_async_body_producer_content_length(self):
response = self.fetch(
"/echo_post",
method="POST",
body_producer=self.async_body_producer,
headers={"Content-Length": "8"},
)
response.rethrow()
self.assertEqual(response.body, b"12345678")
def test_native_body_producer_chunked(self):
async def body_producer(write):
await write(b"1234")
import asyncio
await asyncio.sleep(0)
await write(b"5678")
response = self.fetch("/echo_post", method="POST", body_producer=body_producer)
response.rethrow()
self.assertEqual(response.body, b"12345678")
def test_native_body_producer_content_length(self):
async def body_producer(write):
await write(b"1234")
import asyncio
await asyncio.sleep(0)
await write(b"5678")
response = self.fetch(
"/echo_post",
method="POST",
body_producer=body_producer,
headers={"Content-Length": "8"},
)
response.rethrow()
self.assertEqual(response.body, b"12345678")
def test_100_continue(self):
response = self.fetch(
"/echo_post", method="POST", body=b"1234", expect_100_continue=True
)
self.assertEqual(response.body, b"1234")
def test_100_continue_early_response(self):
def body_producer(write):
raise Exception("should not be called")
response = self.fetch(
"/respond_in_prepare",
method="POST",
body_producer=body_producer,
expect_100_continue=True,
)
self.assertEqual(response.code, 403)
def test_streaming_follow_redirects(self):
# When following redirects, header and streaming callbacks
# should only be called for the final result.
# TODO(bdarnell): this test belongs in httpclient_test instead of
# simple_httpclient_test, but it fails with the version of libcurl
# available on travis-ci. Move it when that has been upgraded
# or we have a better framework to skip tests based on curl version.
headers = [] # type: typing.List[str]
chunk_bytes = [] # type: typing.List[bytes]
self.fetch(
"/redirect?url=/hello",
header_callback=headers.append,
streaming_callback=chunk_bytes.append,
)
chunks = list(map(to_unicode, chunk_bytes))
self.assertEqual(chunks, ["Hello world!"])
# Make sure we only got one set of headers.
num_start_lines = len([h for h in headers if h.startswith("HTTP/")])
self.assertEqual(num_start_lines, 1)
class SimpleHTTPClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPTestCase):
def setUp(self):
super(SimpleHTTPClientTestCase, self).setUp()
self.http_client = self.create_client()
def create_client(self, **kwargs):
return SimpleAsyncHTTPClient(force_instance=True, **kwargs)
class SimpleHTTPSClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPSTestCase):
def setUp(self):
super(SimpleHTTPSClientTestCase, self).setUp()
self.http_client = self.create_client()
def create_client(self, **kwargs):
return SimpleAsyncHTTPClient(
force_instance=True, defaults=dict(validate_cert=False), **kwargs
)
def test_ssl_options(self):
resp = self.fetch("/hello", ssl_options={})
self.assertEqual(resp.body, b"Hello world!")
def test_ssl_context(self):
resp = self.fetch("/hello", ssl_options=ssl.SSLContext(ssl.PROTOCOL_SSLv23))
self.assertEqual(resp.body, b"Hello world!")
def test_ssl_options_handshake_fail(self):
with ExpectLog(gen_log, "SSL Error|Uncaught exception", required=False):
with self.assertRaises(ssl.SSLError):
self.fetch(
"/hello",
ssl_options=dict(cert_reqs=ssl.CERT_REQUIRED),
raise_error=True,
)
def test_ssl_context_handshake_fail(self):
with ExpectLog(gen_log, "SSL Error|Uncaught exception"):
ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
ctx.verify_mode = ssl.CERT_REQUIRED
with self.assertRaises(ssl.SSLError):
self.fetch("/hello", ssl_options=ctx, raise_error=True)
def test_error_logging(self):
# No stack traces are logged for SSL errors (in this case,
# failure to validate the testing self-signed cert).
# The SSLError is exposed through ssl.SSLError.
with ExpectLog(gen_log, ".*") as expect_log:
with self.assertRaises(ssl.SSLError):
self.fetch("/", validate_cert=True, raise_error=True)
self.assertFalse(expect_log.logged_stack)
class CreateAsyncHTTPClientTestCase(AsyncTestCase):
def setUp(self):
super(CreateAsyncHTTPClientTestCase, self).setUp()
self.saved = AsyncHTTPClient._save_configuration()
def tearDown(self):
AsyncHTTPClient._restore_configuration(self.saved)
super(CreateAsyncHTTPClientTestCase, self).tearDown()
def test_max_clients(self):
AsyncHTTPClient.configure(SimpleAsyncHTTPClient)
with closing(AsyncHTTPClient(force_instance=True)) as client:
self.assertEqual(client.max_clients, 10) # type: ignore
with closing(AsyncHTTPClient(max_clients=11, force_instance=True)) as client:
self.assertEqual(client.max_clients, 11) # type: ignore
# Now configure max_clients statically and try overriding it
# with each way max_clients can be passed
AsyncHTTPClient.configure(SimpleAsyncHTTPClient, max_clients=12)
with closing(AsyncHTTPClient(force_instance=True)) as client:
self.assertEqual(client.max_clients, 12) # type: ignore
with closing(AsyncHTTPClient(max_clients=13, force_instance=True)) as client:
self.assertEqual(client.max_clients, 13) # type: ignore
with closing(AsyncHTTPClient(max_clients=14, force_instance=True)) as client:
self.assertEqual(client.max_clients, 14) # type: ignore
class HTTP100ContinueTestCase(AsyncHTTPTestCase):
def respond_100(self, request):
self.http1 = request.version.startswith("HTTP/1.")
if not self.http1:
request.connection.write_headers(
ResponseStartLine("", 200, "OK"), HTTPHeaders()
)
request.connection.finish()
return
self.request = request
fut = self.request.connection.stream.write(b"HTTP/1.1 100 CONTINUE\r\n\r\n")
fut.add_done_callback(self.respond_200)
def respond_200(self, fut):
fut.result()
fut = self.request.connection.stream.write(
b"HTTP/1.1 200 OK\r\nContent-Length: 1\r\n\r\nA"
)
fut.add_done_callback(lambda f: self.request.connection.stream.close())
def get_app(self):
# Not a full Application, but works as an HTTPServer callback
return self.respond_100
def test_100_continue(self):
res = self.fetch("/")
if not self.http1:
self.skipTest("requires HTTP/1.x")
self.assertEqual(res.body, b"A")
class HTTP204NoContentTestCase(AsyncHTTPTestCase):
def respond_204(self, request):
self.http1 = request.version.startswith("HTTP/1.")
if not self.http1:
# Close the request cleanly in HTTP/2; it will be skipped anyway.
request.connection.write_headers(
ResponseStartLine("", 200, "OK"), HTTPHeaders()
)
request.connection.finish()
return
# A 204 response never has a body, even if doesn't have a content-length
# (which would otherwise mean read-until-close). We simulate here a
# server that sends no content length and does not close the connection.
#
# Tests of a 204 response with no Content-Length header are included
# in SimpleHTTPClientTestMixin.
stream = request.connection.detach()
stream.write(b"HTTP/1.1 204 No content\r\n")
if request.arguments.get("error", [False])[-1]:
stream.write(b"Content-Length: 5\r\n")
else:
stream.write(b"Content-Length: 0\r\n")
stream.write(b"\r\n")
stream.close()
def get_app(self):
return self.respond_204
def test_204_no_content(self):
resp = self.fetch("/")
if not self.http1:
self.skipTest("requires HTTP/1.x")
self.assertEqual(resp.code, 204)
self.assertEqual(resp.body, b"")
def test_204_invalid_content_length(self):
# 204 status with non-zero content length is malformed
with ExpectLog(gen_log, ".*Response with code 204 should not have body"):
with self.assertRaises(HTTPStreamClosedError):
self.fetch("/?error=1", raise_error=True)
if not self.http1:
self.skipTest("requires HTTP/1.x")
if self.http_client.configured_class != SimpleAsyncHTTPClient:
self.skipTest("curl client accepts invalid headers")
class HostnameMappingTestCase(AsyncHTTPTestCase):
def setUp(self):
super(HostnameMappingTestCase, self).setUp()
self.http_client = SimpleAsyncHTTPClient(
hostname_mapping={
"www.example.com": "127.0.0.1",
("foo.example.com", 8000): ("127.0.0.1", self.get_http_port()),
}
)
def get_app(self):
return Application([url("/hello", HelloWorldHandler)])
def test_hostname_mapping(self):
response = self.fetch("http://www.example.com:%d/hello" % self.get_http_port())
response.rethrow()
self.assertEqual(response.body, b"Hello world!")
def test_port_mapping(self):
response = self.fetch("http://foo.example.com:8000/hello")
response.rethrow()
self.assertEqual(response.body, b"Hello world!")
class ResolveTimeoutTestCase(AsyncHTTPTestCase):
def setUp(self):
self.cleanup_event = Event()
test = self
# Dummy Resolver subclass that never finishes.
class BadResolver(Resolver):
@gen.coroutine
def resolve(self, *args, **kwargs):
yield test.cleanup_event.wait()
# Return something valid so the test doesn't raise during cleanup.
return [(socket.AF_INET, ("127.0.0.1", test.get_http_port()))]
super(ResolveTimeoutTestCase, self).setUp()
self.http_client = SimpleAsyncHTTPClient(resolver=BadResolver())
def get_app(self):
return Application([url("/hello", HelloWorldHandler)])
def test_resolve_timeout(self):
with self.assertRaises(HTTPTimeoutError):
self.fetch("/hello", connect_timeout=0.1, raise_error=True)
# Let the hanging coroutine clean up after itself
self.cleanup_event.set()
self.io_loop.run_sync(lambda: gen.sleep(0))
class MaxHeaderSizeTest(AsyncHTTPTestCase):
def get_app(self):
class SmallHeaders(RequestHandler):
def get(self):
self.set_header("X-Filler", "a" * 100)
self.write("ok")
class LargeHeaders(RequestHandler):
def get(self):
self.set_header("X-Filler", "a" * 1000)
self.write("ok")
return Application([("/small", SmallHeaders), ("/large", LargeHeaders)])
def get_http_client(self):
return SimpleAsyncHTTPClient(max_header_size=1024)
def test_small_headers(self):
response = self.fetch("/small")
response.rethrow()
self.assertEqual(response.body, b"ok")
def test_large_headers(self):
with ExpectLog(gen_log, "Unsatisfiable read"):
with self.assertRaises(UnsatisfiableReadError):
self.fetch("/large", raise_error=True)
class MaxBodySizeTest(AsyncHTTPTestCase):
def get_app(self):
class SmallBody(RequestHandler):
def get(self):
self.write("a" * 1024 * 64)
class LargeBody(RequestHandler):
def get(self):
self.write("a" * 1024 * 100)
return Application([("/small", SmallBody), ("/large", LargeBody)])
def get_http_client(self):
return SimpleAsyncHTTPClient(max_body_size=1024 * 64)
def test_small_body(self):
response = self.fetch("/small")
response.rethrow()
self.assertEqual(response.body, b"a" * 1024 * 64)
def test_large_body(self):
with ExpectLog(
gen_log, "Malformed HTTP message from None: Content-Length too long"
):
with self.assertRaises(HTTPStreamClosedError):
self.fetch("/large", raise_error=True)
class MaxBufferSizeTest(AsyncHTTPTestCase):
def get_app(self):
class LargeBody(RequestHandler):
def get(self):
self.write("a" * 1024 * 100)
return Application([("/large", LargeBody)])
def get_http_client(self):
# 100KB body with 64KB buffer
return SimpleAsyncHTTPClient(
max_body_size=1024 * 100, max_buffer_size=1024 * 64
)
def test_large_body(self):
response = self.fetch("/large")
response.rethrow()
self.assertEqual(response.body, b"a" * 1024 * 100)
class ChunkedWithContentLengthTest(AsyncHTTPTestCase):
def get_app(self):
class ChunkedWithContentLength(RequestHandler):
def get(self):
# Add an invalid Transfer-Encoding to the response
self.set_header("Transfer-Encoding", "chunked")
self.write("Hello world")
return Application([("/chunkwithcl", ChunkedWithContentLength)])
def get_http_client(self):
return SimpleAsyncHTTPClient()
def test_chunked_with_content_length(self):
# Make sure the invalid headers are detected
with ExpectLog(
gen_log,
(
"Malformed HTTP message from None: Response "
"with both Transfer-Encoding and Content-Length"
),
):
with self.assertRaises(HTTPStreamClosedError):
self.fetch("/chunkwithcl", raise_error=True)

View File

@@ -0,0 +1 @@
this is the index

View File

@@ -0,0 +1,2 @@
User-agent: *
Disallow: /

View File

@@ -0,0 +1,23 @@
<?xml version="1.0"?>
<data>
<country name="Liechtenstein">
<rank>1</rank>
<year>2008</year>
<gdppc>141100</gdppc>
<neighbor name="Austria" direction="E"/>
<neighbor name="Switzerland" direction="W"/>
</country>
<country name="Singapore">
<rank>4</rank>
<year>2011</year>
<gdppc>59900</gdppc>
<neighbor name="Malaysia" direction="N"/>
</country>
<country name="Panama">
<rank>68</rank>
<year>2011</year>
<gdppc>13600</gdppc>
<neighbor name="Costa Rica" direction="W"/>
<neighbor name="Colombia" direction="E"/>
</country>
</data>

View File

@@ -0,0 +1,2 @@
This file should not be served by StaticFileHandler even though
its name starts with "static".

View File

@@ -0,0 +1,435 @@
#
# Copyright 2014 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from contextlib import closing
import os
import socket
import unittest
from tornado.concurrent import Future
from tornado.netutil import bind_sockets, Resolver
from tornado.queues import Queue
from tornado.tcpclient import TCPClient, _Connector
from tornado.tcpserver import TCPServer
from tornado.testing import AsyncTestCase, gen_test
from tornado.test.util import skipIfNoIPv6, refusing_port, skipIfNonUnix
from tornado.gen import TimeoutError
import typing
if typing.TYPE_CHECKING:
from tornado.iostream import IOStream # noqa: F401
from typing import List, Dict, Tuple # noqa: F401
# Fake address families for testing. Used in place of AF_INET
# and AF_INET6 because some installations do not have AF_INET6.
AF1, AF2 = 1, 2
class TestTCPServer(TCPServer):
def __init__(self, family):
super(TestTCPServer, self).__init__()
self.streams = [] # type: List[IOStream]
self.queue = Queue() # type: Queue[IOStream]
sockets = bind_sockets(0, "localhost", family)
self.add_sockets(sockets)
self.port = sockets[0].getsockname()[1]
def handle_stream(self, stream, address):
self.streams.append(stream)
self.queue.put(stream)
def stop(self):
super(TestTCPServer, self).stop()
for stream in self.streams:
stream.close()
class TCPClientTest(AsyncTestCase):
def setUp(self):
super(TCPClientTest, self).setUp()
self.server = None
self.client = TCPClient()
def start_server(self, family):
if family == socket.AF_UNSPEC and "TRAVIS" in os.environ:
self.skipTest("dual-stack servers often have port conflicts on travis")
self.server = TestTCPServer(family)
return self.server.port
def stop_server(self):
if self.server is not None:
self.server.stop()
self.server = None
def tearDown(self):
self.client.close()
self.stop_server()
super(TCPClientTest, self).tearDown()
def skipIfLocalhostV4(self):
# The port used here doesn't matter, but some systems require it
# to be non-zero if we do not also pass AI_PASSIVE.
addrinfo = self.io_loop.run_sync(lambda: Resolver().resolve("localhost", 80))
families = set(addr[0] for addr in addrinfo)
if socket.AF_INET6 not in families:
self.skipTest("localhost does not resolve to ipv6")
@gen_test
def do_test_connect(self, family, host, source_ip=None, source_port=None):
port = self.start_server(family)
stream = yield self.client.connect(
host, port, source_ip=source_ip, source_port=source_port
)
server_stream = yield self.server.queue.get()
with closing(stream):
stream.write(b"hello")
data = yield server_stream.read_bytes(5)
self.assertEqual(data, b"hello")
def test_connect_ipv4_ipv4(self):
self.do_test_connect(socket.AF_INET, "127.0.0.1")
def test_connect_ipv4_dual(self):
self.do_test_connect(socket.AF_INET, "localhost")
@skipIfNoIPv6
def test_connect_ipv6_ipv6(self):
self.skipIfLocalhostV4()
self.do_test_connect(socket.AF_INET6, "::1")
@skipIfNoIPv6
def test_connect_ipv6_dual(self):
self.skipIfLocalhostV4()
if Resolver.configured_class().__name__.endswith("TwistedResolver"):
self.skipTest("TwistedResolver does not support multiple addresses")
self.do_test_connect(socket.AF_INET6, "localhost")
def test_connect_unspec_ipv4(self):
self.do_test_connect(socket.AF_UNSPEC, "127.0.0.1")
@skipIfNoIPv6
def test_connect_unspec_ipv6(self):
self.skipIfLocalhostV4()
self.do_test_connect(socket.AF_UNSPEC, "::1")
def test_connect_unspec_dual(self):
self.do_test_connect(socket.AF_UNSPEC, "localhost")
@gen_test
def test_refused_ipv4(self):
cleanup_func, port = refusing_port()
self.addCleanup(cleanup_func)
with self.assertRaises(IOError):
yield self.client.connect("127.0.0.1", port)
def test_source_ip_fail(self):
"""
Fail when trying to use the source IP Address '8.8.8.8'.
"""
self.assertRaises(
socket.error,
self.do_test_connect,
socket.AF_INET,
"127.0.0.1",
source_ip="8.8.8.8",
)
def test_source_ip_success(self):
"""
Success when trying to use the source IP Address '127.0.0.1'
"""
self.do_test_connect(socket.AF_INET, "127.0.0.1", source_ip="127.0.0.1")
@skipIfNonUnix
def test_source_port_fail(self):
"""
Fail when trying to use source port 1.
"""
self.assertRaises(
socket.error,
self.do_test_connect,
socket.AF_INET,
"127.0.0.1",
source_port=1,
)
@gen_test
def test_connect_timeout(self):
timeout = 0.05
class TimeoutResolver(Resolver):
def resolve(self, *args, **kwargs):
return Future() # never completes
with self.assertRaises(TimeoutError):
yield TCPClient(resolver=TimeoutResolver()).connect(
"1.2.3.4", 12345, timeout=timeout
)
class TestConnectorSplit(unittest.TestCase):
def test_one_family(self):
# These addresses aren't in the right format, but split doesn't care.
primary, secondary = _Connector.split([(AF1, "a"), (AF1, "b")])
self.assertEqual(primary, [(AF1, "a"), (AF1, "b")])
self.assertEqual(secondary, [])
def test_mixed(self):
primary, secondary = _Connector.split(
[(AF1, "a"), (AF2, "b"), (AF1, "c"), (AF2, "d")]
)
self.assertEqual(primary, [(AF1, "a"), (AF1, "c")])
self.assertEqual(secondary, [(AF2, "b"), (AF2, "d")])
class ConnectorTest(AsyncTestCase):
class FakeStream(object):
def __init__(self):
self.closed = False
def close(self):
self.closed = True
def setUp(self):
super(ConnectorTest, self).setUp()
self.connect_futures = (
{}
) # type: Dict[Tuple[int, Tuple], Future[ConnectorTest.FakeStream]]
self.streams = {} # type: Dict[Tuple, ConnectorTest.FakeStream]
self.addrinfo = [(AF1, "a"), (AF1, "b"), (AF2, "c"), (AF2, "d")]
def tearDown(self):
# Unless explicitly checked (and popped) in the test, we shouldn't
# be closing any streams
for stream in self.streams.values():
self.assertFalse(stream.closed)
super(ConnectorTest, self).tearDown()
def create_stream(self, af, addr):
stream = ConnectorTest.FakeStream()
self.streams[addr] = stream
future = Future() # type: Future[ConnectorTest.FakeStream]
self.connect_futures[(af, addr)] = future
return stream, future
def assert_pending(self, *keys):
self.assertEqual(sorted(self.connect_futures.keys()), sorted(keys))
def resolve_connect(self, af, addr, success):
future = self.connect_futures.pop((af, addr))
if success:
future.set_result(self.streams[addr])
else:
self.streams.pop(addr)
future.set_exception(IOError())
# Run the loop to allow callbacks to be run.
self.io_loop.add_callback(self.stop)
self.wait()
def assert_connector_streams_closed(self, conn):
for stream in conn.streams:
self.assertTrue(stream.closed)
def start_connect(self, addrinfo):
conn = _Connector(addrinfo, self.create_stream)
# Give it a huge timeout; we'll trigger timeouts manually.
future = conn.start(3600, connect_timeout=self.io_loop.time() + 3600)
return conn, future
def test_immediate_success(self):
conn, future = self.start_connect(self.addrinfo)
self.assertEqual(list(self.connect_futures.keys()), [(AF1, "a")])
self.resolve_connect(AF1, "a", True)
self.assertEqual(future.result(), (AF1, "a", self.streams["a"]))
def test_immediate_failure(self):
# Fail with just one address.
conn, future = self.start_connect([(AF1, "a")])
self.assert_pending((AF1, "a"))
self.resolve_connect(AF1, "a", False)
self.assertRaises(IOError, future.result)
def test_one_family_second_try(self):
conn, future = self.start_connect([(AF1, "a"), (AF1, "b")])
self.assert_pending((AF1, "a"))
self.resolve_connect(AF1, "a", False)
self.assert_pending((AF1, "b"))
self.resolve_connect(AF1, "b", True)
self.assertEqual(future.result(), (AF1, "b", self.streams["b"]))
def test_one_family_second_try_failure(self):
conn, future = self.start_connect([(AF1, "a"), (AF1, "b")])
self.assert_pending((AF1, "a"))
self.resolve_connect(AF1, "a", False)
self.assert_pending((AF1, "b"))
self.resolve_connect(AF1, "b", False)
self.assertRaises(IOError, future.result)
def test_one_family_second_try_timeout(self):
conn, future = self.start_connect([(AF1, "a"), (AF1, "b")])
self.assert_pending((AF1, "a"))
# trigger the timeout while the first lookup is pending;
# nothing happens.
conn.on_timeout()
self.assert_pending((AF1, "a"))
self.resolve_connect(AF1, "a", False)
self.assert_pending((AF1, "b"))
self.resolve_connect(AF1, "b", True)
self.assertEqual(future.result(), (AF1, "b", self.streams["b"]))
def test_two_families_immediate_failure(self):
conn, future = self.start_connect(self.addrinfo)
self.assert_pending((AF1, "a"))
self.resolve_connect(AF1, "a", False)
self.assert_pending((AF1, "b"), (AF2, "c"))
self.resolve_connect(AF1, "b", False)
self.resolve_connect(AF2, "c", True)
self.assertEqual(future.result(), (AF2, "c", self.streams["c"]))
def test_two_families_timeout(self):
conn, future = self.start_connect(self.addrinfo)
self.assert_pending((AF1, "a"))
conn.on_timeout()
self.assert_pending((AF1, "a"), (AF2, "c"))
self.resolve_connect(AF2, "c", True)
self.assertEqual(future.result(), (AF2, "c", self.streams["c"]))
# resolving 'a' after the connection has completed doesn't start 'b'
self.resolve_connect(AF1, "a", False)
self.assert_pending()
def test_success_after_timeout(self):
conn, future = self.start_connect(self.addrinfo)
self.assert_pending((AF1, "a"))
conn.on_timeout()
self.assert_pending((AF1, "a"), (AF2, "c"))
self.resolve_connect(AF1, "a", True)
self.assertEqual(future.result(), (AF1, "a", self.streams["a"]))
# resolving 'c' after completion closes the connection.
self.resolve_connect(AF2, "c", True)
self.assertTrue(self.streams.pop("c").closed)
def test_all_fail(self):
conn, future = self.start_connect(self.addrinfo)
self.assert_pending((AF1, "a"))
conn.on_timeout()
self.assert_pending((AF1, "a"), (AF2, "c"))
self.resolve_connect(AF2, "c", False)
self.assert_pending((AF1, "a"), (AF2, "d"))
self.resolve_connect(AF2, "d", False)
# one queue is now empty
self.assert_pending((AF1, "a"))
self.resolve_connect(AF1, "a", False)
self.assert_pending((AF1, "b"))
self.assertFalse(future.done())
self.resolve_connect(AF1, "b", False)
self.assertRaises(IOError, future.result)
def test_one_family_timeout_after_connect_timeout(self):
conn, future = self.start_connect([(AF1, "a"), (AF1, "b")])
self.assert_pending((AF1, "a"))
conn.on_connect_timeout()
# the connector will close all streams on connect timeout, we
# should explicitly pop the connect_future.
self.connect_futures.pop((AF1, "a"))
self.assertTrue(self.streams.pop("a").closed)
conn.on_timeout()
# if the future is set with TimeoutError, we will not iterate next
# possible address.
self.assert_pending()
self.assertEqual(len(conn.streams), 1)
self.assert_connector_streams_closed(conn)
self.assertRaises(TimeoutError, future.result)
def test_one_family_success_before_connect_timeout(self):
conn, future = self.start_connect([(AF1, "a"), (AF1, "b")])
self.assert_pending((AF1, "a"))
self.resolve_connect(AF1, "a", True)
conn.on_connect_timeout()
self.assert_pending()
self.assertEqual(self.streams["a"].closed, False)
# success stream will be pop
self.assertEqual(len(conn.streams), 0)
# streams in connector should be closed after connect timeout
self.assert_connector_streams_closed(conn)
self.assertEqual(future.result(), (AF1, "a", self.streams["a"]))
def test_one_family_second_try_after_connect_timeout(self):
conn, future = self.start_connect([(AF1, "a"), (AF1, "b")])
self.assert_pending((AF1, "a"))
self.resolve_connect(AF1, "a", False)
self.assert_pending((AF1, "b"))
conn.on_connect_timeout()
self.connect_futures.pop((AF1, "b"))
self.assertTrue(self.streams.pop("b").closed)
self.assert_pending()
self.assertEqual(len(conn.streams), 2)
self.assert_connector_streams_closed(conn)
self.assertRaises(TimeoutError, future.result)
def test_one_family_second_try_failure_before_connect_timeout(self):
conn, future = self.start_connect([(AF1, "a"), (AF1, "b")])
self.assert_pending((AF1, "a"))
self.resolve_connect(AF1, "a", False)
self.assert_pending((AF1, "b"))
self.resolve_connect(AF1, "b", False)
conn.on_connect_timeout()
self.assert_pending()
self.assertEqual(len(conn.streams), 2)
self.assert_connector_streams_closed(conn)
self.assertRaises(IOError, future.result)
def test_two_family_timeout_before_connect_timeout(self):
conn, future = self.start_connect(self.addrinfo)
self.assert_pending((AF1, "a"))
conn.on_timeout()
self.assert_pending((AF1, "a"), (AF2, "c"))
conn.on_connect_timeout()
self.connect_futures.pop((AF1, "a"))
self.assertTrue(self.streams.pop("a").closed)
self.connect_futures.pop((AF2, "c"))
self.assertTrue(self.streams.pop("c").closed)
self.assert_pending()
self.assertEqual(len(conn.streams), 2)
self.assert_connector_streams_closed(conn)
self.assertRaises(TimeoutError, future.result)
def test_two_family_success_after_timeout(self):
conn, future = self.start_connect(self.addrinfo)
self.assert_pending((AF1, "a"))
conn.on_timeout()
self.assert_pending((AF1, "a"), (AF2, "c"))
self.resolve_connect(AF1, "a", True)
# if one of streams succeed, connector will close all other streams
self.connect_futures.pop((AF2, "c"))
self.assertTrue(self.streams.pop("c").closed)
self.assert_pending()
self.assertEqual(len(conn.streams), 1)
self.assert_connector_streams_closed(conn)
self.assertEqual(future.result(), (AF1, "a", self.streams["a"]))
def test_two_family_timeout_after_connect_timeout(self):
conn, future = self.start_connect(self.addrinfo)
self.assert_pending((AF1, "a"))
conn.on_connect_timeout()
self.connect_futures.pop((AF1, "a"))
self.assertTrue(self.streams.pop("a").closed)
self.assert_pending()
conn.on_timeout()
# if the future is set with TimeoutError, connector will not
# trigger secondary address.
self.assert_pending()
self.assertEqual(len(conn.streams), 1)
self.assert_connector_streams_closed(conn)
self.assertRaises(TimeoutError, future.result)

View File

@@ -0,0 +1,192 @@
import socket
import subprocess
import sys
import textwrap
import unittest
from tornado.escape import utf8, to_unicode
from tornado import gen
from tornado.iostream import IOStream
from tornado.log import app_log
from tornado.tcpserver import TCPServer
from tornado.test.util import skipIfNonUnix
from tornado.testing import AsyncTestCase, ExpectLog, bind_unused_port, gen_test
class TCPServerTest(AsyncTestCase):
@gen_test
def test_handle_stream_coroutine_logging(self):
# handle_stream may be a coroutine and any exception in its
# Future will be logged.
class TestServer(TCPServer):
@gen.coroutine
def handle_stream(self, stream, address):
yield stream.read_bytes(len(b"hello"))
stream.close()
1 / 0
server = client = None
try:
sock, port = bind_unused_port()
server = TestServer()
server.add_socket(sock)
client = IOStream(socket.socket())
with ExpectLog(app_log, "Exception in callback"):
yield client.connect(("localhost", port))
yield client.write(b"hello")
yield client.read_until_close()
yield gen.moment
finally:
if server is not None:
server.stop()
if client is not None:
client.close()
@gen_test
def test_handle_stream_native_coroutine(self):
# handle_stream may be a native coroutine.
class TestServer(TCPServer):
async def handle_stream(self, stream, address):
stream.write(b"data")
stream.close()
sock, port = bind_unused_port()
server = TestServer()
server.add_socket(sock)
client = IOStream(socket.socket())
yield client.connect(("localhost", port))
result = yield client.read_until_close()
self.assertEqual(result, b"data")
server.stop()
client.close()
def test_stop_twice(self):
sock, port = bind_unused_port()
server = TCPServer()
server.add_socket(sock)
server.stop()
server.stop()
@gen_test
def test_stop_in_callback(self):
# Issue #2069: calling server.stop() in a loop callback should not
# raise EBADF when the loop handles other server connection
# requests in the same loop iteration
class TestServer(TCPServer):
@gen.coroutine
def handle_stream(self, stream, address):
server.stop() # type: ignore
yield stream.read_until_close()
sock, port = bind_unused_port()
server = TestServer()
server.add_socket(sock)
server_addr = ("localhost", port)
N = 40
clients = [IOStream(socket.socket()) for i in range(N)]
connected_clients = []
@gen.coroutine
def connect(c):
try:
yield c.connect(server_addr)
except EnvironmentError:
pass
else:
connected_clients.append(c)
yield [connect(c) for c in clients]
self.assertGreater(len(connected_clients), 0, "all clients failed connecting")
try:
if len(connected_clients) == N:
# Ideally we'd make the test deterministic, but we're testing
# for a race condition in combination with the system's TCP stack...
self.skipTest(
"at least one client should fail connecting "
"for the test to be meaningful"
)
finally:
for c in connected_clients:
c.close()
# Here tearDown() would re-raise the EBADF encountered in the IO loop
@skipIfNonUnix
class TestMultiprocess(unittest.TestCase):
# These tests verify that the two multiprocess examples from the
# TCPServer docs work. Both tests start a server with three worker
# processes, each of which prints its task id to stdout (a single
# byte, so we don't have to worry about atomicity of the shared
# stdout stream) and then exits.
def run_subproc(self, code):
proc = subprocess.Popen(
sys.executable, stdin=subprocess.PIPE, stdout=subprocess.PIPE
)
proc.stdin.write(utf8(code))
proc.stdin.close()
proc.wait()
stdout = proc.stdout.read()
proc.stdout.close()
if proc.returncode != 0:
raise RuntimeError(
"Process returned %d. stdout=%r" % (proc.returncode, stdout)
)
return to_unicode(stdout)
def test_single(self):
# As a sanity check, run the single-process version through this test
# harness too.
code = textwrap.dedent(
"""
from tornado.ioloop import IOLoop
from tornado.tcpserver import TCPServer
server = TCPServer()
server.listen(0, address='127.0.0.1')
IOLoop.current().run_sync(lambda: None)
print('012', end='')
"""
)
out = self.run_subproc(code)
self.assertEqual("".join(sorted(out)), "012")
def test_simple(self):
code = textwrap.dedent(
"""
from tornado.ioloop import IOLoop
from tornado.process import task_id
from tornado.tcpserver import TCPServer
server = TCPServer()
server.bind(0, address='127.0.0.1')
server.start(3)
IOLoop.current().run_sync(lambda: None)
print(task_id(), end='')
"""
)
out = self.run_subproc(code)
self.assertEqual("".join(sorted(out)), "012")
def test_advanced(self):
code = textwrap.dedent(
"""
from tornado.ioloop import IOLoop
from tornado.netutil import bind_sockets
from tornado.process import fork_processes, task_id
from tornado.ioloop import IOLoop
from tornado.tcpserver import TCPServer
sockets = bind_sockets(0, address='127.0.0.1')
fork_processes(3)
server = TCPServer()
server.add_sockets(sockets)
IOLoop.current().run_sync(lambda: None)
print(task_id(), end='')
"""
)
out = self.run_subproc(code)
self.assertEqual("".join(sorted(out)), "012")

View File

@@ -0,0 +1,536 @@
import os
import traceback
import unittest
from tornado.escape import utf8, native_str, to_unicode
from tornado.template import Template, DictLoader, ParseError, Loader
from tornado.util import ObjectDict
import typing # noqa: F401
class TemplateTest(unittest.TestCase):
def test_simple(self):
template = Template("Hello {{ name }}!")
self.assertEqual(template.generate(name="Ben"), b"Hello Ben!")
def test_bytes(self):
template = Template("Hello {{ name }}!")
self.assertEqual(template.generate(name=utf8("Ben")), b"Hello Ben!")
def test_expressions(self):
template = Template("2 + 2 = {{ 2 + 2 }}")
self.assertEqual(template.generate(), b"2 + 2 = 4")
def test_comment(self):
template = Template("Hello{# TODO i18n #} {{ name }}!")
self.assertEqual(template.generate(name=utf8("Ben")), b"Hello Ben!")
def test_include(self):
loader = DictLoader(
{
"index.html": '{% include "header.html" %}\nbody text',
"header.html": "header text",
}
)
self.assertEqual(
loader.load("index.html").generate(), b"header text\nbody text"
)
def test_extends(self):
loader = DictLoader(
{
"base.html": """\
<title>{% block title %}default title{% end %}</title>
<body>{% block body %}default body{% end %}</body>
""",
"page.html": """\
{% extends "base.html" %}
{% block title %}page title{% end %}
{% block body %}page body{% end %}
""",
}
)
self.assertEqual(
loader.load("page.html").generate(),
b"<title>page title</title>\n<body>page body</body>\n",
)
def test_relative_load(self):
loader = DictLoader(
{
"a/1.html": "{% include '2.html' %}",
"a/2.html": "{% include '../b/3.html' %}",
"b/3.html": "ok",
}
)
self.assertEqual(loader.load("a/1.html").generate(), b"ok")
def test_escaping(self):
self.assertRaises(ParseError, lambda: Template("{{"))
self.assertRaises(ParseError, lambda: Template("{%"))
self.assertEqual(Template("{{!").generate(), b"{{")
self.assertEqual(Template("{%!").generate(), b"{%")
self.assertEqual(Template("{#!").generate(), b"{#")
self.assertEqual(
Template("{{ 'expr' }} {{!jquery expr}}").generate(),
b"expr {{jquery expr}}",
)
def test_unicode_template(self):
template = Template(utf8(u"\u00e9"))
self.assertEqual(template.generate(), utf8(u"\u00e9"))
def test_unicode_literal_expression(self):
# Unicode literals should be usable in templates. Note that this
# test simulates unicode characters appearing directly in the
# template file (with utf8 encoding), i.e. \u escapes would not
# be used in the template file itself.
template = Template(utf8(u'{{ "\u00e9" }}'))
self.assertEqual(template.generate(), utf8(u"\u00e9"))
def test_custom_namespace(self):
loader = DictLoader(
{"test.html": "{{ inc(5) }}"}, namespace={"inc": lambda x: x + 1}
)
self.assertEqual(loader.load("test.html").generate(), b"6")
def test_apply(self):
def upper(s):
return s.upper()
template = Template(utf8("{% apply upper %}foo{% end %}"))
self.assertEqual(template.generate(upper=upper), b"FOO")
def test_unicode_apply(self):
def upper(s):
return to_unicode(s).upper()
template = Template(utf8(u"{% apply upper %}foo \u00e9{% end %}"))
self.assertEqual(template.generate(upper=upper), utf8(u"FOO \u00c9"))
def test_bytes_apply(self):
def upper(s):
return utf8(to_unicode(s).upper())
template = Template(utf8(u"{% apply upper %}foo \u00e9{% end %}"))
self.assertEqual(template.generate(upper=upper), utf8(u"FOO \u00c9"))
def test_if(self):
template = Template(utf8("{% if x > 4 %}yes{% else %}no{% end %}"))
self.assertEqual(template.generate(x=5), b"yes")
self.assertEqual(template.generate(x=3), b"no")
def test_if_empty_body(self):
template = Template(utf8("{% if True %}{% else %}{% end %}"))
self.assertEqual(template.generate(), b"")
def test_try(self):
template = Template(
utf8(
"""{% try %}
try{% set y = 1/x %}
{% except %}-except
{% else %}-else
{% finally %}-finally
{% end %}"""
)
)
self.assertEqual(template.generate(x=1), b"\ntry\n-else\n-finally\n")
self.assertEqual(template.generate(x=0), b"\ntry-except\n-finally\n")
def test_comment_directive(self):
template = Template(utf8("{% comment blah blah %}foo"))
self.assertEqual(template.generate(), b"foo")
def test_break_continue(self):
template = Template(
utf8(
"""\
{% for i in range(10) %}
{% if i == 2 %}
{% continue %}
{% end %}
{{ i }}
{% if i == 6 %}
{% break %}
{% end %}
{% end %}"""
)
)
result = template.generate()
# remove extraneous whitespace
result = b"".join(result.split())
self.assertEqual(result, b"013456")
def test_break_outside_loop(self):
try:
Template(utf8("{% break %}"))
raise Exception("Did not get expected exception")
except ParseError:
pass
def test_break_in_apply(self):
# This test verifies current behavior, although of course it would
# be nice if apply didn't cause seemingly unrelated breakage
try:
Template(
utf8("{% for i in [] %}{% apply foo %}{% break %}{% end %}{% end %}")
)
raise Exception("Did not get expected exception")
except ParseError:
pass
@unittest.skip("no testable future imports")
def test_no_inherit_future(self):
# TODO(bdarnell): make a test like this for one of the future
# imports available in python 3. Unfortunately they're harder
# to use in a template than division was.
# This file has from __future__ import division...
self.assertEqual(1 / 2, 0.5)
# ...but the template doesn't
template = Template("{{ 1 / 2 }}")
self.assertEqual(template.generate(), "0")
def test_non_ascii_name(self):
loader = DictLoader({u"t\u00e9st.html": "hello"})
self.assertEqual(loader.load(u"t\u00e9st.html").generate(), b"hello")
class StackTraceTest(unittest.TestCase):
def test_error_line_number_expression(self):
loader = DictLoader(
{
"test.html": """one
two{{1/0}}
three
"""
}
)
try:
loader.load("test.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
self.assertTrue("# test.html:2" in traceback.format_exc())
def test_error_line_number_directive(self):
loader = DictLoader(
{
"test.html": """one
two{%if 1/0%}
three{%end%}
"""
}
)
try:
loader.load("test.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
self.assertTrue("# test.html:2" in traceback.format_exc())
def test_error_line_number_module(self):
loader = None # type: typing.Optional[DictLoader]
def load_generate(path, **kwargs):
assert loader is not None
return loader.load(path).generate(**kwargs)
loader = DictLoader(
{"base.html": "{% module Template('sub.html') %}", "sub.html": "{{1/0}}"},
namespace={"_tt_modules": ObjectDict(Template=load_generate)},
)
try:
loader.load("base.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
exc_stack = traceback.format_exc()
self.assertTrue("# base.html:1" in exc_stack)
self.assertTrue("# sub.html:1" in exc_stack)
def test_error_line_number_include(self):
loader = DictLoader(
{"base.html": "{% include 'sub.html' %}", "sub.html": "{{1/0}}"}
)
try:
loader.load("base.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
self.assertTrue("# sub.html:1 (via base.html:1)" in traceback.format_exc())
def test_error_line_number_extends_base_error(self):
loader = DictLoader(
{"base.html": "{{1/0}}", "sub.html": "{% extends 'base.html' %}"}
)
try:
loader.load("sub.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
exc_stack = traceback.format_exc()
self.assertTrue("# base.html:1" in exc_stack)
def test_error_line_number_extends_sub_error(self):
loader = DictLoader(
{
"base.html": "{% block 'block' %}{% end %}",
"sub.html": """
{% extends 'base.html' %}
{% block 'block' %}
{{1/0}}
{% end %}
""",
}
)
try:
loader.load("sub.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
self.assertTrue("# sub.html:4 (via base.html:1)" in traceback.format_exc())
def test_multi_includes(self):
loader = DictLoader(
{
"a.html": "{% include 'b.html' %}",
"b.html": "{% include 'c.html' %}",
"c.html": "{{1/0}}",
}
)
try:
loader.load("a.html").generate()
self.fail("did not get expected exception")
except ZeroDivisionError:
self.assertTrue(
"# c.html:1 (via b.html:1, a.html:1)" in traceback.format_exc()
)
class ParseErrorDetailTest(unittest.TestCase):
def test_details(self):
loader = DictLoader({"foo.html": "\n\n{{"})
with self.assertRaises(ParseError) as cm:
loader.load("foo.html")
self.assertEqual("Missing end expression }} at foo.html:3", str(cm.exception))
self.assertEqual("foo.html", cm.exception.filename)
self.assertEqual(3, cm.exception.lineno)
def test_custom_parse_error(self):
# Make sure that ParseErrors remain compatible with their
# pre-4.3 signature.
self.assertEqual("asdf at None:0", str(ParseError("asdf")))
class AutoEscapeTest(unittest.TestCase):
def setUp(self):
self.templates = {
"escaped.html": "{% autoescape xhtml_escape %}{{ name }}",
"unescaped.html": "{% autoescape None %}{{ name }}",
"default.html": "{{ name }}",
"include.html": """\
escaped: {% include 'escaped.html' %}
unescaped: {% include 'unescaped.html' %}
default: {% include 'default.html' %}
""",
"escaped_block.html": """\
{% autoescape xhtml_escape %}\
{% block name %}base: {{ name }}{% end %}""",
"unescaped_block.html": """\
{% autoescape None %}\
{% block name %}base: {{ name }}{% end %}""",
# Extend a base template with different autoescape policy,
# with and without overriding the base's blocks
"escaped_extends_unescaped.html": """\
{% autoescape xhtml_escape %}\
{% extends "unescaped_block.html" %}""",
"escaped_overrides_unescaped.html": """\
{% autoescape xhtml_escape %}\
{% extends "unescaped_block.html" %}\
{% block name %}extended: {{ name }}{% end %}""",
"unescaped_extends_escaped.html": """\
{% autoescape None %}\
{% extends "escaped_block.html" %}""",
"unescaped_overrides_escaped.html": """\
{% autoescape None %}\
{% extends "escaped_block.html" %}\
{% block name %}extended: {{ name }}{% end %}""",
"raw_expression.html": """\
{% autoescape xhtml_escape %}\
expr: {{ name }}
raw: {% raw name %}""",
}
def test_default_off(self):
loader = DictLoader(self.templates, autoescape=None)
name = "Bobby <table>s"
self.assertEqual(
loader.load("escaped.html").generate(name=name), b"Bobby &lt;table&gt;s"
)
self.assertEqual(
loader.load("unescaped.html").generate(name=name), b"Bobby <table>s"
)
self.assertEqual(
loader.load("default.html").generate(name=name), b"Bobby <table>s"
)
self.assertEqual(
loader.load("include.html").generate(name=name),
b"escaped: Bobby &lt;table&gt;s\n"
b"unescaped: Bobby <table>s\n"
b"default: Bobby <table>s\n",
)
def test_default_on(self):
loader = DictLoader(self.templates, autoescape="xhtml_escape")
name = "Bobby <table>s"
self.assertEqual(
loader.load("escaped.html").generate(name=name), b"Bobby &lt;table&gt;s"
)
self.assertEqual(
loader.load("unescaped.html").generate(name=name), b"Bobby <table>s"
)
self.assertEqual(
loader.load("default.html").generate(name=name), b"Bobby &lt;table&gt;s"
)
self.assertEqual(
loader.load("include.html").generate(name=name),
b"escaped: Bobby &lt;table&gt;s\n"
b"unescaped: Bobby <table>s\n"
b"default: Bobby &lt;table&gt;s\n",
)
def test_unextended_block(self):
loader = DictLoader(self.templates)
name = "<script>"
self.assertEqual(
loader.load("escaped_block.html").generate(name=name),
b"base: &lt;script&gt;",
)
self.assertEqual(
loader.load("unescaped_block.html").generate(name=name), b"base: <script>"
)
def test_extended_block(self):
loader = DictLoader(self.templates)
def render(name):
return loader.load(name).generate(name="<script>")
self.assertEqual(render("escaped_extends_unescaped.html"), b"base: <script>")
self.assertEqual(
render("escaped_overrides_unescaped.html"), b"extended: &lt;script&gt;"
)
self.assertEqual(
render("unescaped_extends_escaped.html"), b"base: &lt;script&gt;"
)
self.assertEqual(
render("unescaped_overrides_escaped.html"), b"extended: <script>"
)
def test_raw_expression(self):
loader = DictLoader(self.templates)
def render(name):
return loader.load(name).generate(name='<>&"')
self.assertEqual(
render("raw_expression.html"), b"expr: &lt;&gt;&amp;&quot;\n" b'raw: <>&"'
)
def test_custom_escape(self):
loader = DictLoader({"foo.py": "{% autoescape py_escape %}s = {{ name }}\n"})
def py_escape(s):
self.assertEqual(type(s), bytes)
return repr(native_str(s))
def render(template, name):
return loader.load(template).generate(py_escape=py_escape, name=name)
self.assertEqual(render("foo.py", "<html>"), b"s = '<html>'\n")
self.assertEqual(render("foo.py", "';sys.exit()"), b"""s = "';sys.exit()"\n""")
self.assertEqual(
render("foo.py", ["not a string"]), b"""s = "['not a string']"\n"""
)
def test_manual_minimize_whitespace(self):
# Whitespace including newlines is allowed within template tags
# and directives, and this is one way to avoid long lines while
# keeping extra whitespace out of the rendered output.
loader = DictLoader(
{
"foo.txt": """\
{% for i in items
%}{% if i > 0 %}, {% end %}{#
#}{{i
}}{% end
%}"""
}
)
self.assertEqual(
loader.load("foo.txt").generate(items=range(5)), b"0, 1, 2, 3, 4"
)
def test_whitespace_by_filename(self):
# Default whitespace handling depends on the template filename.
loader = DictLoader(
{
"foo.html": " \n\t\n asdf\t ",
"bar.js": " \n\n\n\t qwer ",
"baz.txt": "\t zxcv\n\n",
"include.html": " {% include baz.txt %} \n ",
"include.txt": "\t\t{% include foo.html %} ",
}
)
# HTML and JS files have whitespace compressed by default.
self.assertEqual(loader.load("foo.html").generate(), b"\nasdf ")
self.assertEqual(loader.load("bar.js").generate(), b"\nqwer ")
# TXT files do not.
self.assertEqual(loader.load("baz.txt").generate(), b"\t zxcv\n\n")
# Each file maintains its own status even when included in
# a file of the other type.
self.assertEqual(loader.load("include.html").generate(), b" \t zxcv\n\n\n")
self.assertEqual(loader.load("include.txt").generate(), b"\t\t\nasdf ")
def test_whitespace_by_loader(self):
templates = {"foo.html": "\t\tfoo\n\n", "bar.txt": "\t\tbar\n\n"}
loader = DictLoader(templates, whitespace="all")
self.assertEqual(loader.load("foo.html").generate(), b"\t\tfoo\n\n")
self.assertEqual(loader.load("bar.txt").generate(), b"\t\tbar\n\n")
loader = DictLoader(templates, whitespace="single")
self.assertEqual(loader.load("foo.html").generate(), b" foo\n")
self.assertEqual(loader.load("bar.txt").generate(), b" bar\n")
loader = DictLoader(templates, whitespace="oneline")
self.assertEqual(loader.load("foo.html").generate(), b" foo ")
self.assertEqual(loader.load("bar.txt").generate(), b" bar ")
def test_whitespace_directive(self):
loader = DictLoader(
{
"foo.html": """\
{% whitespace oneline %}
{% for i in range(3) %}
{{ i }}
{% end %}
{% whitespace all %}
pre\tformatted
"""
}
)
self.assertEqual(
loader.load("foo.html").generate(), b" 0 1 2 \n pre\tformatted\n"
)
class TemplateLoaderTest(unittest.TestCase):
def setUp(self):
self.loader = Loader(os.path.join(os.path.dirname(__file__), "templates"))
def test_utf8_in_file(self):
tmpl = self.loader.load("utf8.html")
result = tmpl.generate()
self.assertEqual(to_unicode(result).strip(), u"H\u00e9llo")

View File

@@ -0,0 +1 @@
Héllo

View File

@@ -0,0 +1,20 @@
-----BEGIN CERTIFICATE-----
MIIDWzCCAkOgAwIBAgIUV4spou0CenmvKqa7Hml/MC+JKiAwDQYJKoZIhvcNAQEL
BQAwPTELMAkGA1UEBhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExGTAXBgNVBAoM
EFRvcm5hZG8gV2ViIFRlc3QwHhcNMTgwOTI5MTM1NjQ1WhcNMjgwOTI2MTM1NjQ1
WjA9MQswCQYDVQQGEwJVUzETMBEGA1UECAwKQ2FsaWZvcm5pYTEZMBcGA1UECgwQ
VG9ybmFkbyBXZWIgVGVzdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB
AKT0LdyI8tW5uwP3ahE8BFSz+j3SsKBDv/0cKvqxVVE6sLEST2s3HjArZvIIG5sb
iBkWDrqnZ6UKDvB4jlobLGAkepxDbrxHWxK53n0C28XXGLqJQ01TlTZ5rpjttMeg
5SKNjHbxpOvpUwwQS4br4WjZKKyTGiXpFkFUty+tYVU35/U2yyvreWHmzpHx/25t
H7O2RBARVwJYKOGPtlH62lQjpIWfVfklY4Ip8Hjl3B6rBxPyBULmVQw0qgoZn648
oa4oLjs0wnYBz01gVjNMDHej52SsB/ieH7W1TxFMzqOlcvHh41uFbQJPgcXsruSS
9Z4twzSWkUp2vk/C//4Sz38CAwEAAaNTMFEwHQYDVR0OBBYEFLf8fQ5+u8sDWAd3
r5ZjZ5MmDWJeMB8GA1UdIwQYMBaAFLf8fQ5+u8sDWAd3r5ZjZ5MmDWJeMA8GA1Ud
EwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggEBADkkm3pIb9IeqVNmQ2uhQOgw
UwyToTYUHNTb/Nm5lzBTBqC8gbXAS24RQ30AB/7G115Uxeo+YMKfITxm/CgR+vhF
F59/YrzwXj+G8bdbuVl/UbB6f9RSp+Zo93rUZAtPWr77gxLUrcwSRzzDwxFjC2nC
6eigbkvt1OQY775RwnFAt7HKPclE0Out+cGJIboJuO1f3r57ZdyFH0GzbZEff/7K
atGXohijWJjYvU4mk0KFHORZrcBpsv9cfkFbmgVmiRwxRJ1tLauHM3Ne+VfqYE5M
4rTStSyz3ASqVKJ2iFMQueNR/tUOuDlfRt+0nhJMuYSSkW+KTgnwyOGU9cv+mxA=
-----END CERTIFICATE-----

View File

@@ -0,0 +1,28 @@
-----BEGIN PRIVATE KEY-----
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCk9C3ciPLVubsD
92oRPARUs/o90rCgQ7/9HCr6sVVROrCxEk9rNx4wK2byCBubG4gZFg66p2elCg7w
eI5aGyxgJHqcQ268R1sSud59AtvF1xi6iUNNU5U2ea6Y7bTHoOUijYx28aTr6VMM
EEuG6+Fo2Siskxol6RZBVLcvrWFVN+f1Nssr63lh5s6R8f9ubR+ztkQQEVcCWCjh
j7ZR+tpUI6SFn1X5JWOCKfB45dweqwcT8gVC5lUMNKoKGZ+uPKGuKC47NMJ2Ac9N
YFYzTAx3o+dkrAf4nh+1tU8RTM6jpXLx4eNbhW0CT4HF7K7kkvWeLcM0lpFKdr5P
wv/+Es9/AgMBAAECggEABi6AaXtYXloPgB6NgwfUwbfc8OQsalUfpMShd7OdluW0
KW6eO05de0ClIvzay/1EJGyHMMeFQtIVrT1XWFkcWJ4FWkXMqJGkABenFtg8lDVz
X8o1E3jGZrw4ptKBq9mDvL/BO9PiclTUH+ecbPn6AIvi0lTQ7grGIryiAM9mjmLy
jpCwoutF2LD4RPNg8vqWe/Z1rQw5lp8FOHhRwPooHHeoq1bSrp8dqvVAwAam7Mmf
uFgI8jrNycPgr2cwEEtbq2TQ625MhVnCpwT+kErmAStfbXXuqv1X1ZZgiNxf+61C
OL0bhPRVIHmmjiK/5qHRuN4Q5u9/Yp2SJ4W5xadSQQKBgQDR7dnOlYYQiaoPJeD/
7jcLVJbWwbr7bE19O/QpYAtkA/FtGlKr+hQxPhK6OYp+in8eHf+ga/NSAjCWRBoh
MNAVCJtiirHo2tFsLFOmlJpGL9n3sX8UnkJN90oHfWrzJ8BZnXaSw2eOuyw8LLj+
Q+ISl6Go8/xfsuy3EDv4AP1wCwKBgQDJJ4vEV3Kr+bc6N/xeu+G0oHvRAWwuQpcx
9D+XpnqbJbFDnWKNE7oGsDCs8Qjr0CdFUN1pm1ppITDZ5N1cWuDg/47ZAXqEK6D1
z13S7O0oQPlnsPL7mHs2Vl73muAaBPAojFvceHHfccr7Z94BXqKsiyfaWz6kclT/
Nl4JTdsC3QKBgQCeYgozL2J/da2lUhnIXcyPstk+29kbueFYu/QBh2HwqnzqqLJ4
5+t2H3P3plQUFp/DdDSZrvhcBiTsKiNgqThEtkKtfSCvIvBf4a2W/4TJsW6MzxCm
2KQDuK/UqM4Y+APKWN/N6Lln2VWNbNyBkWuuRVKFatccyJyJnSjxeqW7cwKBgGyN
idCYPIrwROAHLItXKvOWE5t0ABRq3TsZC2RkdA/b5HCPs4pclexcEriRjvXrK/Yt
MH94Ve8b+UftSUQ4ytjBMS6MrLg87y0YDhLwxv8NKUq65DXAUOW+8JsAmmWQOqY3
MK+m1BT4TMklgVoN3w3sPsKIsSJ/jLz5cv/kYweFAoGAG4iWU1378tI2Ts/Fngsv
7eoWhoda77Y9D0Yoy20aN9VdMHzIYCBOubtRPEuwgaReNwbUBWap01J63yY/fF3K
8PTz6covjoOJqxQJOvM7nM0CsJawG9ccw3YXyd9KgRIdSt6ooEhb7N8W2EXYoKl3
g1i2t41Q/SC3HUGC5mJjpO8=
-----END PRIVATE KEY-----

View File

@@ -0,0 +1,350 @@
from tornado import gen, ioloop
from tornado.httpserver import HTTPServer
from tornado.locks import Event
from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, bind_unused_port, gen_test
from tornado.web import Application
import asyncio
import contextlib
import gc
import os
import platform
import traceback
import unittest
import warnings
@contextlib.contextmanager
def set_environ(name, value):
old_value = os.environ.get(name)
os.environ[name] = value
try:
yield
finally:
if old_value is None:
del os.environ[name]
else:
os.environ[name] = old_value
class AsyncTestCaseTest(AsyncTestCase):
def test_wait_timeout(self):
time = self.io_loop.time
# Accept default 5-second timeout, no error
self.io_loop.add_timeout(time() + 0.01, self.stop)
self.wait()
# Timeout passed to wait()
self.io_loop.add_timeout(time() + 1, self.stop)
with self.assertRaises(self.failureException):
self.wait(timeout=0.01)
# Timeout set with environment variable
self.io_loop.add_timeout(time() + 1, self.stop)
with set_environ("ASYNC_TEST_TIMEOUT", "0.01"):
with self.assertRaises(self.failureException):
self.wait()
def test_subsequent_wait_calls(self):
"""
This test makes sure that a second call to wait()
clears the first timeout.
"""
self.io_loop.add_timeout(self.io_loop.time() + 0.00, self.stop)
self.wait(timeout=0.02)
self.io_loop.add_timeout(self.io_loop.time() + 0.03, self.stop)
self.wait(timeout=0.15)
class LeakTest(AsyncTestCase):
def tearDown(self):
super().tearDown()
# Trigger a gc to make warnings more deterministic.
gc.collect()
def test_leaked_coroutine(self):
# This test verifies that "leaked" coroutines are shut down
# without triggering warnings like "task was destroyed but it
# is pending". If this test were to fail, it would fail
# because runtests.py detected unexpected output to stderr.
event = Event()
async def callback():
try:
await event.wait()
except asyncio.CancelledError:
pass
self.io_loop.add_callback(callback)
self.io_loop.add_callback(self.stop)
self.wait()
class AsyncHTTPTestCaseTest(AsyncHTTPTestCase):
def setUp(self):
super(AsyncHTTPTestCaseTest, self).setUp()
# Bind a second port.
sock, port = bind_unused_port()
app = Application()
server = HTTPServer(app, **self.get_httpserver_options())
server.add_socket(sock)
self.second_port = port
self.second_server = server
def get_app(self):
return Application()
def test_fetch_segment(self):
path = "/path"
response = self.fetch(path)
self.assertEqual(response.request.url, self.get_url(path))
def test_fetch_full_http_url(self):
# Ensure that self.fetch() recognizes absolute urls and does
# not transform them into references to our main test server.
path = "http://localhost:%d/path" % self.second_port
response = self.fetch(path)
self.assertEqual(response.request.url, path)
def tearDown(self):
self.second_server.stop()
super(AsyncHTTPTestCaseTest, self).tearDown()
class AsyncTestCaseWrapperTest(unittest.TestCase):
def test_undecorated_generator(self):
class Test(AsyncTestCase):
def test_gen(self):
yield
test = Test("test_gen")
result = unittest.TestResult()
test.run(result)
self.assertEqual(len(result.errors), 1)
self.assertIn("should be decorated", result.errors[0][1])
@unittest.skipIf(
platform.python_implementation() == "PyPy",
"pypy destructor warnings cannot be silenced",
)
def test_undecorated_coroutine(self):
class Test(AsyncTestCase):
async def test_coro(self):
pass
test = Test("test_coro")
result = unittest.TestResult()
# Silence "RuntimeWarning: coroutine 'test_coro' was never awaited".
with warnings.catch_warnings():
warnings.simplefilter("ignore")
test.run(result)
self.assertEqual(len(result.errors), 1)
self.assertIn("should be decorated", result.errors[0][1])
def test_undecorated_generator_with_skip(self):
class Test(AsyncTestCase):
@unittest.skip("don't run this")
def test_gen(self):
yield
test = Test("test_gen")
result = unittest.TestResult()
test.run(result)
self.assertEqual(len(result.errors), 0)
self.assertEqual(len(result.skipped), 1)
def test_other_return(self):
class Test(AsyncTestCase):
def test_other_return(self):
return 42
test = Test("test_other_return")
result = unittest.TestResult()
test.run(result)
self.assertEqual(len(result.errors), 1)
self.assertIn("Return value from test method ignored", result.errors[0][1])
class SetUpTearDownTest(unittest.TestCase):
def test_set_up_tear_down(self):
"""
This test makes sure that AsyncTestCase calls super methods for
setUp and tearDown.
InheritBoth is a subclass of both AsyncTestCase and
SetUpTearDown, with the ordering so that the super of
AsyncTestCase will be SetUpTearDown.
"""
events = []
result = unittest.TestResult()
class SetUpTearDown(unittest.TestCase):
def setUp(self):
events.append("setUp")
def tearDown(self):
events.append("tearDown")
class InheritBoth(AsyncTestCase, SetUpTearDown):
def test(self):
events.append("test")
InheritBoth("test").run(result)
expected = ["setUp", "test", "tearDown"]
self.assertEqual(expected, events)
class AsyncHTTPTestCaseSetUpTearDownTest(unittest.TestCase):
def test_tear_down_releases_app_and_http_server(self):
result = unittest.TestResult()
class SetUpTearDown(AsyncHTTPTestCase):
def get_app(self):
return Application()
def test(self):
self.assertTrue(hasattr(self, "_app"))
self.assertTrue(hasattr(self, "http_server"))
test = SetUpTearDown("test")
test.run(result)
self.assertFalse(hasattr(test, "_app"))
self.assertFalse(hasattr(test, "http_server"))
class GenTest(AsyncTestCase):
def setUp(self):
super(GenTest, self).setUp()
self.finished = False
def tearDown(self):
self.assertTrue(self.finished)
super(GenTest, self).tearDown()
@gen_test
def test_sync(self):
self.finished = True
@gen_test
def test_async(self):
yield gen.moment
self.finished = True
def test_timeout(self):
# Set a short timeout and exceed it.
@gen_test(timeout=0.1)
def test(self):
yield gen.sleep(1)
# This can't use assertRaises because we need to inspect the
# exc_info triple (and not just the exception object)
try:
test(self)
self.fail("did not get expected exception")
except ioloop.TimeoutError:
# The stack trace should blame the add_timeout line, not just
# unrelated IOLoop/testing internals.
self.assertIn("gen.sleep(1)", traceback.format_exc())
self.finished = True
def test_no_timeout(self):
# A test that does not exceed its timeout should succeed.
@gen_test(timeout=1)
def test(self):
yield gen.sleep(0.1)
test(self)
self.finished = True
def test_timeout_environment_variable(self):
@gen_test(timeout=0.5)
def test_long_timeout(self):
yield gen.sleep(0.25)
# Uses provided timeout of 0.5 seconds, doesn't time out.
with set_environ("ASYNC_TEST_TIMEOUT", "0.1"):
test_long_timeout(self)
self.finished = True
def test_no_timeout_environment_variable(self):
@gen_test(timeout=0.01)
def test_short_timeout(self):
yield gen.sleep(1)
# Uses environment-variable timeout of 0.1, times out.
with set_environ("ASYNC_TEST_TIMEOUT", "0.1"):
with self.assertRaises(ioloop.TimeoutError):
test_short_timeout(self)
self.finished = True
def test_with_method_args(self):
@gen_test
def test_with_args(self, *args):
self.assertEqual(args, ("test",))
yield gen.moment
test_with_args(self, "test")
self.finished = True
def test_with_method_kwargs(self):
@gen_test
def test_with_kwargs(self, **kwargs):
self.assertDictEqual(kwargs, {"test": "test"})
yield gen.moment
test_with_kwargs(self, test="test")
self.finished = True
def test_native_coroutine(self):
@gen_test
async def test(self):
self.finished = True
test(self)
def test_native_coroutine_timeout(self):
# Set a short timeout and exceed it.
@gen_test(timeout=0.1)
async def test(self):
await gen.sleep(1)
try:
test(self)
self.fail("did not get expected exception")
except ioloop.TimeoutError:
self.finished = True
class GetNewIOLoopTest(AsyncTestCase):
def get_new_ioloop(self):
# Use the current loop instead of creating a new one here.
return ioloop.IOLoop.current()
def setUp(self):
# This simulates the effect of an asyncio test harness like
# pytest-asyncio.
self.orig_loop = asyncio.get_event_loop()
self.new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.new_loop)
super(GetNewIOLoopTest, self).setUp()
def tearDown(self):
super(GetNewIOLoopTest, self).tearDown()
# AsyncTestCase must not affect the existing asyncio loop.
self.assertFalse(asyncio.get_event_loop().is_closed())
asyncio.set_event_loop(self.orig_loop)
self.new_loop.close()
def test_loop(self):
self.assertIs(self.io_loop.asyncio_loop, self.new_loop)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,236 @@
# Author: Ovidiu Predescu
# Date: July 2011
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import asyncio
import logging
import signal
import unittest
import warnings
from tornado.escape import utf8
from tornado import gen
from tornado.httpclient import AsyncHTTPClient
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
from tornado.testing import bind_unused_port, AsyncTestCase, gen_test
from tornado.web import RequestHandler, Application
try:
from twisted.internet.defer import ( # type: ignore
Deferred,
inlineCallbacks,
returnValue,
)
from twisted.internet.protocol import Protocol # type: ignore
from twisted.internet.asyncioreactor import AsyncioSelectorReactor # type: ignore
from twisted.web.client import Agent, readBody # type: ignore
from twisted.web.resource import Resource # type: ignore
from twisted.web.server import Site # type: ignore
have_twisted = True
except ImportError:
have_twisted = False
else:
# Not used directly but needed for `yield deferred` to work.
import tornado.platform.twisted # noqa: F401
skipIfNoTwisted = unittest.skipUnless(have_twisted, "twisted module not present")
def save_signal_handlers():
saved = {}
for sig in [signal.SIGINT, signal.SIGTERM, signal.SIGCHLD]:
saved[sig] = signal.getsignal(sig)
if "twisted" in repr(saved):
# This indicates we're not cleaning up after ourselves properly.
raise Exception("twisted signal handlers already installed")
return saved
def restore_signal_handlers(saved):
for sig, handler in saved.items():
signal.signal(sig, handler)
# Test various combinations of twisted and tornado http servers,
# http clients, and event loop interfaces.
@skipIfNoTwisted
class CompatibilityTests(unittest.TestCase):
def setUp(self):
self.saved_signals = save_signal_handlers()
self.io_loop = IOLoop()
self.io_loop.make_current()
self.reactor = AsyncioSelectorReactor()
def tearDown(self):
self.reactor.disconnectAll()
self.io_loop.clear_current()
self.io_loop.close(all_fds=True)
restore_signal_handlers(self.saved_signals)
def start_twisted_server(self):
class HelloResource(Resource):
isLeaf = True
def render_GET(self, request):
return b"Hello from twisted!"
site = Site(HelloResource())
port = self.reactor.listenTCP(0, site, interface="127.0.0.1")
self.twisted_port = port.getHost().port
def start_tornado_server(self):
class HelloHandler(RequestHandler):
def get(self):
self.write("Hello from tornado!")
app = Application([("/", HelloHandler)], log_function=lambda x: None)
server = HTTPServer(app)
sock, self.tornado_port = bind_unused_port()
server.add_sockets([sock])
def run_reactor(self):
# In theory, we can run the event loop through Tornado,
# Twisted, or asyncio interfaces. However, since we're trying
# to avoid installing anything as the global event loop, only
# the twisted interface gets everything wired up correectly
# without extra hacks. This method is a part of a
# no-longer-used generalization that allowed us to test
# different combinations.
self.stop_loop = self.reactor.stop
self.stop = self.reactor.stop
self.reactor.run()
def tornado_fetch(self, url, runner):
client = AsyncHTTPClient()
fut = asyncio.ensure_future(client.fetch(url))
fut.add_done_callback(lambda f: self.stop_loop())
runner()
return fut.result()
def twisted_fetch(self, url, runner):
# http://twistedmatrix.com/documents/current/web/howto/client.html
chunks = []
client = Agent(self.reactor)
d = client.request(b"GET", utf8(url))
class Accumulator(Protocol):
def __init__(self, finished):
self.finished = finished
def dataReceived(self, data):
chunks.append(data)
def connectionLost(self, reason):
self.finished.callback(None)
def callback(response):
finished = Deferred()
response.deliverBody(Accumulator(finished))
return finished
d.addCallback(callback)
def shutdown(failure):
if hasattr(self, "stop_loop"):
self.stop_loop()
elif failure is not None:
# loop hasn't been initialized yet; try our best to
# get an error message out. (the runner() interaction
# should probably be refactored).
try:
failure.raiseException()
except:
logging.error("exception before starting loop", exc_info=True)
d.addBoth(shutdown)
runner()
self.assertTrue(chunks)
return b"".join(chunks)
def twisted_coroutine_fetch(self, url, runner):
body = [None]
@gen.coroutine
def f():
# This is simpler than the non-coroutine version, but it cheats
# by reading the body in one blob instead of streaming it with
# a Protocol.
client = Agent(self.reactor)
response = yield client.request(b"GET", utf8(url))
with warnings.catch_warnings():
# readBody has a buggy DeprecationWarning in Twisted 15.0:
# https://twistedmatrix.com/trac/changeset/43379
warnings.simplefilter("ignore", category=DeprecationWarning)
body[0] = yield readBody(response)
self.stop_loop()
self.io_loop.add_callback(f)
runner()
return body[0]
def testTwistedServerTornadoClientReactor(self):
self.start_twisted_server()
response = self.tornado_fetch(
"http://127.0.0.1:%d" % self.twisted_port, self.run_reactor
)
self.assertEqual(response.body, b"Hello from twisted!")
def testTornadoServerTwistedClientReactor(self):
self.start_tornado_server()
response = self.twisted_fetch(
"http://127.0.0.1:%d" % self.tornado_port, self.run_reactor
)
self.assertEqual(response, b"Hello from tornado!")
def testTornadoServerTwistedCoroutineClientReactor(self):
self.start_tornado_server()
response = self.twisted_coroutine_fetch(
"http://127.0.0.1:%d" % self.tornado_port, self.run_reactor
)
self.assertEqual(response, b"Hello from tornado!")
@skipIfNoTwisted
class ConvertDeferredTest(AsyncTestCase):
@gen_test
def test_success(self):
@inlineCallbacks
def fn():
if False:
# inlineCallbacks doesn't work with regular functions;
# must have a yield even if it's unreachable.
yield
returnValue(42)
res = yield fn()
self.assertEqual(res, 42)
@gen_test
def test_failure(self):
@inlineCallbacks
def fn():
if False:
yield
1 / 0
with self.assertRaises(ZeroDivisionError):
yield fn()
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,114 @@
import contextlib
import os
import platform
import socket
import sys
import textwrap
import typing # noqa: F401
import unittest
import warnings
from tornado.testing import bind_unused_port
skipIfNonUnix = unittest.skipIf(
os.name != "posix" or sys.platform == "cygwin", "non-unix platform"
)
# travis-ci.org runs our tests in an overworked virtual machine, which makes
# timing-related tests unreliable.
skipOnTravis = unittest.skipIf(
"TRAVIS" in os.environ, "timing tests unreliable on travis"
)
# Set the environment variable NO_NETWORK=1 to disable any tests that
# depend on an external network.
skipIfNoNetwork = unittest.skipIf("NO_NETWORK" in os.environ, "network access disabled")
skipNotCPython = unittest.skipIf(
platform.python_implementation() != "CPython", "Not CPython implementation"
)
# Used for tests affected by
# https://bitbucket.org/pypy/pypy/issues/2616/incomplete-error-handling-in
# TODO: remove this after pypy3 5.8 is obsolete.
skipPypy3V58 = unittest.skipIf(
platform.python_implementation() == "PyPy"
and sys.version_info > (3,)
and sys.pypy_version_info < (5, 9), # type: ignore
"pypy3 5.8 has buggy ssl module",
)
def _detect_ipv6():
if not socket.has_ipv6:
# socket.has_ipv6 check reports whether ipv6 was present at compile
# time. It's usually true even when ipv6 doesn't work for other reasons.
return False
sock = None
try:
sock = socket.socket(socket.AF_INET6)
sock.bind(("::1", 0))
except socket.error:
return False
finally:
if sock is not None:
sock.close()
return True
skipIfNoIPv6 = unittest.skipIf(not _detect_ipv6(), "ipv6 support not present")
def refusing_port():
"""Returns a local port number that will refuse all connections.
Return value is (cleanup_func, port); the cleanup function
must be called to free the port to be reused.
"""
# On travis-ci, port numbers are reassigned frequently. To avoid
# collisions with other tests, we use an open client-side socket's
# ephemeral port number to ensure that nothing can listen on that
# port.
server_socket, port = bind_unused_port()
server_socket.setblocking(True)
client_socket = socket.socket()
client_socket.connect(("127.0.0.1", port))
conn, client_addr = server_socket.accept()
conn.close()
server_socket.close()
return (client_socket.close, client_addr[1])
def exec_test(caller_globals, caller_locals, s):
"""Execute ``s`` in a given context and return the result namespace.
Used to define functions for tests in particular python
versions that would be syntax errors in older versions.
"""
# Flatten the real global and local namespace into our fake
# globals: it's all global from the perspective of code defined
# in s.
global_namespace = dict(caller_globals, **caller_locals) # type: ignore
local_namespace = {} # type: typing.Dict[str, typing.Any]
exec(textwrap.dedent(s), global_namespace, local_namespace)
return local_namespace
def subTest(test, *args, **kwargs):
"""Compatibility shim for unittest.TestCase.subTest.
Usage: ``with tornado.test.util.subTest(self, x=x):``
"""
try:
subTest = test.subTest # py34+
except AttributeError:
subTest = contextlib.contextmanager(lambda *a, **kw: (yield))
return subTest(*args, **kwargs)
@contextlib.contextmanager
def ignore_deprecation():
"""Context manager to ignore deprecation warnings."""
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
yield

View File

@@ -0,0 +1,309 @@
# coding: utf-8
from io import StringIO
import re
import sys
import datetime
import unittest
import tornado.escape
from tornado.escape import utf8
from tornado.util import (
raise_exc_info,
Configurable,
exec_in,
ArgReplacer,
timedelta_to_seconds,
import_object,
re_unescape,
is_finalizing,
)
import typing
from typing import cast
if typing.TYPE_CHECKING:
from typing import Dict, Any # noqa: F401
class RaiseExcInfoTest(unittest.TestCase):
def test_two_arg_exception(self):
# This test would fail on python 3 if raise_exc_info were simply
# a three-argument raise statement, because TwoArgException
# doesn't have a "copy constructor"
class TwoArgException(Exception):
def __init__(self, a, b):
super(TwoArgException, self).__init__()
self.a, self.b = a, b
try:
raise TwoArgException(1, 2)
except TwoArgException:
exc_info = sys.exc_info()
try:
raise_exc_info(exc_info)
self.fail("didn't get expected exception")
except TwoArgException as e:
self.assertIs(e, exc_info[1])
class TestConfigurable(Configurable):
@classmethod
def configurable_base(cls):
return TestConfigurable
@classmethod
def configurable_default(cls):
return TestConfig1
class TestConfig1(TestConfigurable):
def initialize(self, pos_arg=None, a=None):
self.a = a
self.pos_arg = pos_arg
class TestConfig2(TestConfigurable):
def initialize(self, pos_arg=None, b=None):
self.b = b
self.pos_arg = pos_arg
class TestConfig3(TestConfigurable):
# TestConfig3 is a configuration option that is itself configurable.
@classmethod
def configurable_base(cls):
return TestConfig3
@classmethod
def configurable_default(cls):
return TestConfig3A
class TestConfig3A(TestConfig3):
def initialize(self, a=None):
self.a = a
class TestConfig3B(TestConfig3):
def initialize(self, b=None):
self.b = b
class ConfigurableTest(unittest.TestCase):
def setUp(self):
self.saved = TestConfigurable._save_configuration()
self.saved3 = TestConfig3._save_configuration()
def tearDown(self):
TestConfigurable._restore_configuration(self.saved)
TestConfig3._restore_configuration(self.saved3)
def checkSubclasses(self):
# no matter how the class is configured, it should always be
# possible to instantiate the subclasses directly
self.assertIsInstance(TestConfig1(), TestConfig1)
self.assertIsInstance(TestConfig2(), TestConfig2)
obj = TestConfig1(a=1)
self.assertEqual(obj.a, 1)
obj2 = TestConfig2(b=2)
self.assertEqual(obj2.b, 2)
def test_default(self):
# In these tests we combine a typing.cast to satisfy mypy with
# a runtime type-assertion. Without the cast, mypy would only
# let us access attributes of the base class.
obj = cast(TestConfig1, TestConfigurable())
self.assertIsInstance(obj, TestConfig1)
self.assertIs(obj.a, None)
obj = cast(TestConfig1, TestConfigurable(a=1))
self.assertIsInstance(obj, TestConfig1)
self.assertEqual(obj.a, 1)
self.checkSubclasses()
def test_config_class(self):
TestConfigurable.configure(TestConfig2)
obj = cast(TestConfig2, TestConfigurable())
self.assertIsInstance(obj, TestConfig2)
self.assertIs(obj.b, None)
obj = cast(TestConfig2, TestConfigurable(b=2))
self.assertIsInstance(obj, TestConfig2)
self.assertEqual(obj.b, 2)
self.checkSubclasses()
def test_config_str(self):
TestConfigurable.configure("tornado.test.util_test.TestConfig2")
obj = cast(TestConfig2, TestConfigurable())
self.assertIsInstance(obj, TestConfig2)
self.assertIs(obj.b, None)
obj = cast(TestConfig2, TestConfigurable(b=2))
self.assertIsInstance(obj, TestConfig2)
self.assertEqual(obj.b, 2)
self.checkSubclasses()
def test_config_args(self):
TestConfigurable.configure(None, a=3)
obj = cast(TestConfig1, TestConfigurable())
self.assertIsInstance(obj, TestConfig1)
self.assertEqual(obj.a, 3)
obj = cast(TestConfig1, TestConfigurable(42, a=4))
self.assertIsInstance(obj, TestConfig1)
self.assertEqual(obj.a, 4)
self.assertEqual(obj.pos_arg, 42)
self.checkSubclasses()
# args bound in configure don't apply when using the subclass directly
obj = TestConfig1()
self.assertIs(obj.a, None)
def test_config_class_args(self):
TestConfigurable.configure(TestConfig2, b=5)
obj = cast(TestConfig2, TestConfigurable())
self.assertIsInstance(obj, TestConfig2)
self.assertEqual(obj.b, 5)
obj = cast(TestConfig2, TestConfigurable(42, b=6))
self.assertIsInstance(obj, TestConfig2)
self.assertEqual(obj.b, 6)
self.assertEqual(obj.pos_arg, 42)
self.checkSubclasses()
# args bound in configure don't apply when using the subclass directly
obj = TestConfig2()
self.assertIs(obj.b, None)
def test_config_multi_level(self):
TestConfigurable.configure(TestConfig3, a=1)
obj = cast(TestConfig3A, TestConfigurable())
self.assertIsInstance(obj, TestConfig3A)
self.assertEqual(obj.a, 1)
TestConfigurable.configure(TestConfig3)
TestConfig3.configure(TestConfig3B, b=2)
obj2 = cast(TestConfig3B, TestConfigurable())
self.assertIsInstance(obj2, TestConfig3B)
self.assertEqual(obj2.b, 2)
def test_config_inner_level(self):
# The inner level can be used even when the outer level
# doesn't point to it.
obj = TestConfig3()
self.assertIsInstance(obj, TestConfig3A)
TestConfig3.configure(TestConfig3B)
obj = TestConfig3()
self.assertIsInstance(obj, TestConfig3B)
# Configuring the base doesn't configure the inner.
obj2 = TestConfigurable()
self.assertIsInstance(obj2, TestConfig1)
TestConfigurable.configure(TestConfig2)
obj3 = TestConfigurable()
self.assertIsInstance(obj3, TestConfig2)
obj = TestConfig3()
self.assertIsInstance(obj, TestConfig3B)
class UnicodeLiteralTest(unittest.TestCase):
def test_unicode_escapes(self):
self.assertEqual(utf8(u"\u00e9"), b"\xc3\xa9")
class ExecInTest(unittest.TestCase):
# TODO(bdarnell): make a version of this test for one of the new
# future imports available in python 3.
@unittest.skip("no testable future imports")
def test_no_inherit_future(self):
# This file has from __future__ import print_function...
f = StringIO()
print("hello", file=f)
# ...but the template doesn't
exec_in('print >> f, "world"', dict(f=f))
self.assertEqual(f.getvalue(), "hello\nworld\n")
class ArgReplacerTest(unittest.TestCase):
def setUp(self):
def function(x, y, callback=None, z=None):
pass
self.replacer = ArgReplacer(function, "callback")
def test_omitted(self):
args = (1, 2)
kwargs = dict() # type: Dict[str, Any]
self.assertIs(self.replacer.get_old_value(args, kwargs), None)
self.assertEqual(
self.replacer.replace("new", args, kwargs),
(None, (1, 2), dict(callback="new")),
)
def test_position(self):
args = (1, 2, "old", 3)
kwargs = dict() # type: Dict[str, Any]
self.assertEqual(self.replacer.get_old_value(args, kwargs), "old")
self.assertEqual(
self.replacer.replace("new", args, kwargs),
("old", [1, 2, "new", 3], dict()),
)
def test_keyword(self):
args = (1,)
kwargs = dict(y=2, callback="old", z=3)
self.assertEqual(self.replacer.get_old_value(args, kwargs), "old")
self.assertEqual(
self.replacer.replace("new", args, kwargs),
("old", (1,), dict(y=2, callback="new", z=3)),
)
class TimedeltaToSecondsTest(unittest.TestCase):
def test_timedelta_to_seconds(self):
time_delta = datetime.timedelta(hours=1)
self.assertEqual(timedelta_to_seconds(time_delta), 3600.0)
class ImportObjectTest(unittest.TestCase):
def test_import_member(self):
self.assertIs(import_object("tornado.escape.utf8"), utf8)
def test_import_member_unicode(self):
self.assertIs(import_object(u"tornado.escape.utf8"), utf8)
def test_import_module(self):
self.assertIs(import_object("tornado.escape"), tornado.escape)
def test_import_module_unicode(self):
# The internal implementation of __import__ differs depending on
# whether the thing being imported is a module or not.
# This variant requires a byte string in python 2.
self.assertIs(import_object(u"tornado.escape"), tornado.escape)
class ReUnescapeTest(unittest.TestCase):
def test_re_unescape(self):
test_strings = ("/favicon.ico", "index.html", "Hello, World!", "!$@#%;")
for string in test_strings:
self.assertEqual(string, re_unescape(re.escape(string)))
def test_re_unescape_raises_error_on_invalid_input(self):
with self.assertRaises(ValueError):
re_unescape("\\d")
with self.assertRaises(ValueError):
re_unescape("\\b")
with self.assertRaises(ValueError):
re_unescape("\\Z")
class IsFinalizingTest(unittest.TestCase):
def test_basic(self):
self.assertFalse(is_finalizing())

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,822 @@
import asyncio
import functools
import traceback
import unittest
from tornado.concurrent import Future
from tornado import gen
from tornado.httpclient import HTTPError, HTTPRequest
from tornado.locks import Event
from tornado.log import gen_log, app_log
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.template import DictLoader
from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
from tornado.web import Application, RequestHandler
try:
import tornado.websocket # noqa: F401
from tornado.util import _websocket_mask_python
except ImportError:
# The unittest module presents misleading errors on ImportError
# (it acts as if websocket_test could not be found, hiding the underlying
# error). If we get an ImportError here (which could happen due to
# TORNADO_EXTENSION=1), print some extra information before failing.
traceback.print_exc()
raise
from tornado.websocket import (
WebSocketHandler,
websocket_connect,
WebSocketError,
WebSocketClosedError,
)
try:
from tornado import speedups
except ImportError:
speedups = None # type: ignore
class TestWebSocketHandler(WebSocketHandler):
"""Base class for testing handlers that exposes the on_close event.
This allows for tests to see the close code and reason on the
server side.
"""
def initialize(self, close_future=None, compression_options=None):
self.close_future = close_future
self.compression_options = compression_options
def get_compression_options(self):
return self.compression_options
def on_close(self):
if self.close_future is not None:
self.close_future.set_result((self.close_code, self.close_reason))
class EchoHandler(TestWebSocketHandler):
@gen.coroutine
def on_message(self, message):
try:
yield self.write_message(message, isinstance(message, bytes))
except asyncio.CancelledError:
pass
except WebSocketClosedError:
pass
class ErrorInOnMessageHandler(TestWebSocketHandler):
def on_message(self, message):
1 / 0
class HeaderHandler(TestWebSocketHandler):
def open(self):
methods_to_test = [
functools.partial(self.write, "This should not work"),
functools.partial(self.redirect, "http://localhost/elsewhere"),
functools.partial(self.set_header, "X-Test", ""),
functools.partial(self.set_cookie, "Chocolate", "Chip"),
functools.partial(self.set_status, 503),
self.flush,
self.finish,
]
for method in methods_to_test:
try:
# In a websocket context, many RequestHandler methods
# raise RuntimeErrors.
method()
raise Exception("did not get expected exception")
except RuntimeError:
pass
self.write_message(self.request.headers.get("X-Test", ""))
class HeaderEchoHandler(TestWebSocketHandler):
def set_default_headers(self):
self.set_header("X-Extra-Response-Header", "Extra-Response-Value")
def prepare(self):
for k, v in self.request.headers.get_all():
if k.lower().startswith("x-test"):
self.set_header(k, v)
class NonWebSocketHandler(RequestHandler):
def get(self):
self.write("ok")
class CloseReasonHandler(TestWebSocketHandler):
def open(self):
self.on_close_called = False
self.close(1001, "goodbye")
class AsyncPrepareHandler(TestWebSocketHandler):
@gen.coroutine
def prepare(self):
yield gen.moment
def on_message(self, message):
self.write_message(message)
class PathArgsHandler(TestWebSocketHandler):
def open(self, arg):
self.write_message(arg)
class CoroutineOnMessageHandler(TestWebSocketHandler):
def initialize(self, **kwargs):
super(CoroutineOnMessageHandler, self).initialize(**kwargs)
self.sleeping = 0
@gen.coroutine
def on_message(self, message):
if self.sleeping > 0:
self.write_message("another coroutine is already sleeping")
self.sleeping += 1
yield gen.sleep(0.01)
self.sleeping -= 1
self.write_message(message)
class RenderMessageHandler(TestWebSocketHandler):
def on_message(self, message):
self.write_message(self.render_string("message.html", message=message))
class SubprotocolHandler(TestWebSocketHandler):
def initialize(self, **kwargs):
super(SubprotocolHandler, self).initialize(**kwargs)
self.select_subprotocol_called = False
def select_subprotocol(self, subprotocols):
if self.select_subprotocol_called:
raise Exception("select_subprotocol called twice")
self.select_subprotocol_called = True
if "goodproto" in subprotocols:
return "goodproto"
return None
def open(self):
if not self.select_subprotocol_called:
raise Exception("select_subprotocol not called")
self.write_message("subprotocol=%s" % self.selected_subprotocol)
class OpenCoroutineHandler(TestWebSocketHandler):
def initialize(self, test, **kwargs):
super(OpenCoroutineHandler, self).initialize(**kwargs)
self.test = test
self.open_finished = False
@gen.coroutine
def open(self):
yield self.test.message_sent.wait()
yield gen.sleep(0.010)
self.open_finished = True
def on_message(self, message):
if not self.open_finished:
raise Exception("on_message called before open finished")
self.write_message("ok")
class ErrorInOpenHandler(TestWebSocketHandler):
def open(self):
raise Exception("boom")
class ErrorInAsyncOpenHandler(TestWebSocketHandler):
async def open(self):
await asyncio.sleep(0)
raise Exception("boom")
class NoDelayHandler(TestWebSocketHandler):
def open(self):
self.set_nodelay(True)
self.write_message("hello")
class WebSocketBaseTestCase(AsyncHTTPTestCase):
@gen.coroutine
def ws_connect(self, path, **kwargs):
ws = yield websocket_connect(
"ws://127.0.0.1:%d%s" % (self.get_http_port(), path), **kwargs
)
raise gen.Return(ws)
class WebSocketTest(WebSocketBaseTestCase):
def get_app(self):
self.close_future = Future() # type: Future[None]
return Application(
[
("/echo", EchoHandler, dict(close_future=self.close_future)),
("/non_ws", NonWebSocketHandler),
("/header", HeaderHandler, dict(close_future=self.close_future)),
(
"/header_echo",
HeaderEchoHandler,
dict(close_future=self.close_future),
),
(
"/close_reason",
CloseReasonHandler,
dict(close_future=self.close_future),
),
(
"/error_in_on_message",
ErrorInOnMessageHandler,
dict(close_future=self.close_future),
),
(
"/async_prepare",
AsyncPrepareHandler,
dict(close_future=self.close_future),
),
(
"/path_args/(.*)",
PathArgsHandler,
dict(close_future=self.close_future),
),
(
"/coroutine",
CoroutineOnMessageHandler,
dict(close_future=self.close_future),
),
("/render", RenderMessageHandler, dict(close_future=self.close_future)),
(
"/subprotocol",
SubprotocolHandler,
dict(close_future=self.close_future),
),
(
"/open_coroutine",
OpenCoroutineHandler,
dict(close_future=self.close_future, test=self),
),
("/error_in_open", ErrorInOpenHandler),
("/error_in_async_open", ErrorInAsyncOpenHandler),
("/nodelay", NoDelayHandler),
],
template_loader=DictLoader({"message.html": "<b>{{ message }}</b>"}),
)
def get_http_client(self):
# These tests require HTTP/1; force the use of SimpleAsyncHTTPClient.
return SimpleAsyncHTTPClient()
def tearDown(self):
super(WebSocketTest, self).tearDown()
RequestHandler._template_loaders.clear()
def test_http_request(self):
# WS server, HTTP client.
response = self.fetch("/echo")
self.assertEqual(response.code, 400)
def test_missing_websocket_key(self):
response = self.fetch(
"/echo",
headers={
"Connection": "Upgrade",
"Upgrade": "WebSocket",
"Sec-WebSocket-Version": "13",
},
)
self.assertEqual(response.code, 400)
def test_bad_websocket_version(self):
response = self.fetch(
"/echo",
headers={
"Connection": "Upgrade",
"Upgrade": "WebSocket",
"Sec-WebSocket-Version": "12",
},
)
self.assertEqual(response.code, 426)
@gen_test
def test_websocket_gen(self):
ws = yield self.ws_connect("/echo")
yield ws.write_message("hello")
response = yield ws.read_message()
self.assertEqual(response, "hello")
def test_websocket_callbacks(self):
websocket_connect(
"ws://127.0.0.1:%d/echo" % self.get_http_port(), callback=self.stop
)
ws = self.wait().result()
ws.write_message("hello")
ws.read_message(self.stop)
response = self.wait().result()
self.assertEqual(response, "hello")
self.close_future.add_done_callback(lambda f: self.stop())
ws.close()
self.wait()
@gen_test
def test_binary_message(self):
ws = yield self.ws_connect("/echo")
ws.write_message(b"hello \xe9", binary=True)
response = yield ws.read_message()
self.assertEqual(response, b"hello \xe9")
@gen_test
def test_unicode_message(self):
ws = yield self.ws_connect("/echo")
ws.write_message(u"hello \u00e9")
response = yield ws.read_message()
self.assertEqual(response, u"hello \u00e9")
@gen_test
def test_render_message(self):
ws = yield self.ws_connect("/render")
ws.write_message("hello")
response = yield ws.read_message()
self.assertEqual(response, "<b>hello</b>")
@gen_test
def test_error_in_on_message(self):
ws = yield self.ws_connect("/error_in_on_message")
ws.write_message("hello")
with ExpectLog(app_log, "Uncaught exception"):
response = yield ws.read_message()
self.assertIs(response, None)
@gen_test
def test_websocket_http_fail(self):
with self.assertRaises(HTTPError) as cm:
yield self.ws_connect("/notfound")
self.assertEqual(cm.exception.code, 404)
@gen_test
def test_websocket_http_success(self):
with self.assertRaises(WebSocketError):
yield self.ws_connect("/non_ws")
@gen_test
def test_websocket_network_fail(self):
sock, port = bind_unused_port()
sock.close()
with self.assertRaises(IOError):
with ExpectLog(gen_log, ".*"):
yield websocket_connect(
"ws://127.0.0.1:%d/" % port, connect_timeout=3600
)
@gen_test
def test_websocket_close_buffered_data(self):
ws = yield websocket_connect("ws://127.0.0.1:%d/echo" % self.get_http_port())
ws.write_message("hello")
ws.write_message("world")
# Close the underlying stream.
ws.stream.close()
@gen_test
def test_websocket_headers(self):
# Ensure that arbitrary headers can be passed through websocket_connect.
ws = yield websocket_connect(
HTTPRequest(
"ws://127.0.0.1:%d/header" % self.get_http_port(),
headers={"X-Test": "hello"},
)
)
response = yield ws.read_message()
self.assertEqual(response, "hello")
@gen_test
def test_websocket_header_echo(self):
# Ensure that headers can be returned in the response.
# Specifically, that arbitrary headers passed through websocket_connect
# can be returned.
ws = yield websocket_connect(
HTTPRequest(
"ws://127.0.0.1:%d/header_echo" % self.get_http_port(),
headers={"X-Test-Hello": "hello"},
)
)
self.assertEqual(ws.headers.get("X-Test-Hello"), "hello")
self.assertEqual(
ws.headers.get("X-Extra-Response-Header"), "Extra-Response-Value"
)
@gen_test
def test_server_close_reason(self):
ws = yield self.ws_connect("/close_reason")
msg = yield ws.read_message()
# A message of None means the other side closed the connection.
self.assertIs(msg, None)
self.assertEqual(ws.close_code, 1001)
self.assertEqual(ws.close_reason, "goodbye")
# The on_close callback is called no matter which side closed.
code, reason = yield self.close_future
# The client echoed the close code it received to the server,
# so the server's close code (returned via close_future) is
# the same.
self.assertEqual(code, 1001)
@gen_test
def test_client_close_reason(self):
ws = yield self.ws_connect("/echo")
ws.close(1001, "goodbye")
code, reason = yield self.close_future
self.assertEqual(code, 1001)
self.assertEqual(reason, "goodbye")
@gen_test
def test_write_after_close(self):
ws = yield self.ws_connect("/close_reason")
msg = yield ws.read_message()
self.assertIs(msg, None)
with self.assertRaises(WebSocketClosedError):
ws.write_message("hello")
@gen_test
def test_async_prepare(self):
# Previously, an async prepare method triggered a bug that would
# result in a timeout on test shutdown (and a memory leak).
ws = yield self.ws_connect("/async_prepare")
ws.write_message("hello")
res = yield ws.read_message()
self.assertEqual(res, "hello")
@gen_test
def test_path_args(self):
ws = yield self.ws_connect("/path_args/hello")
res = yield ws.read_message()
self.assertEqual(res, "hello")
@gen_test
def test_coroutine(self):
ws = yield self.ws_connect("/coroutine")
# Send both messages immediately, coroutine must process one at a time.
yield ws.write_message("hello1")
yield ws.write_message("hello2")
res = yield ws.read_message()
self.assertEqual(res, "hello1")
res = yield ws.read_message()
self.assertEqual(res, "hello2")
@gen_test
def test_check_origin_valid_no_path(self):
port = self.get_http_port()
url = "ws://127.0.0.1:%d/echo" % port
headers = {"Origin": "http://127.0.0.1:%d" % port}
ws = yield websocket_connect(HTTPRequest(url, headers=headers))
ws.write_message("hello")
response = yield ws.read_message()
self.assertEqual(response, "hello")
@gen_test
def test_check_origin_valid_with_path(self):
port = self.get_http_port()
url = "ws://127.0.0.1:%d/echo" % port
headers = {"Origin": "http://127.0.0.1:%d/something" % port}
ws = yield websocket_connect(HTTPRequest(url, headers=headers))
ws.write_message("hello")
response = yield ws.read_message()
self.assertEqual(response, "hello")
@gen_test
def test_check_origin_invalid_partial_url(self):
port = self.get_http_port()
url = "ws://127.0.0.1:%d/echo" % port
headers = {"Origin": "127.0.0.1:%d" % port}
with self.assertRaises(HTTPError) as cm:
yield websocket_connect(HTTPRequest(url, headers=headers))
self.assertEqual(cm.exception.code, 403)
@gen_test
def test_check_origin_invalid(self):
port = self.get_http_port()
url = "ws://127.0.0.1:%d/echo" % port
# Host is 127.0.0.1, which should not be accessible from some other
# domain
headers = {"Origin": "http://somewhereelse.com"}
with self.assertRaises(HTTPError) as cm:
yield websocket_connect(HTTPRequest(url, headers=headers))
self.assertEqual(cm.exception.code, 403)
@gen_test
def test_check_origin_invalid_subdomains(self):
port = self.get_http_port()
url = "ws://localhost:%d/echo" % port
# Subdomains should be disallowed by default. If we could pass a
# resolver to websocket_connect we could test sibling domains as well.
headers = {"Origin": "http://subtenant.localhost"}
with self.assertRaises(HTTPError) as cm:
yield websocket_connect(HTTPRequest(url, headers=headers))
self.assertEqual(cm.exception.code, 403)
@gen_test
def test_subprotocols(self):
ws = yield self.ws_connect(
"/subprotocol", subprotocols=["badproto", "goodproto"]
)
self.assertEqual(ws.selected_subprotocol, "goodproto")
res = yield ws.read_message()
self.assertEqual(res, "subprotocol=goodproto")
@gen_test
def test_subprotocols_not_offered(self):
ws = yield self.ws_connect("/subprotocol")
self.assertIs(ws.selected_subprotocol, None)
res = yield ws.read_message()
self.assertEqual(res, "subprotocol=None")
@gen_test
def test_open_coroutine(self):
self.message_sent = Event()
ws = yield self.ws_connect("/open_coroutine")
yield ws.write_message("hello")
self.message_sent.set()
res = yield ws.read_message()
self.assertEqual(res, "ok")
@gen_test
def test_error_in_open(self):
with ExpectLog(app_log, "Uncaught exception"):
ws = yield self.ws_connect("/error_in_open")
res = yield ws.read_message()
self.assertIsNone(res)
@gen_test
def test_error_in_async_open(self):
with ExpectLog(app_log, "Uncaught exception"):
ws = yield self.ws_connect("/error_in_async_open")
res = yield ws.read_message()
self.assertIsNone(res)
@gen_test
def test_nodelay(self):
ws = yield self.ws_connect("/nodelay")
res = yield ws.read_message()
self.assertEqual(res, "hello")
class NativeCoroutineOnMessageHandler(TestWebSocketHandler):
def initialize(self, **kwargs):
super().initialize(**kwargs)
self.sleeping = 0
async def on_message(self, message):
if self.sleeping > 0:
self.write_message("another coroutine is already sleeping")
self.sleeping += 1
await gen.sleep(0.01)
self.sleeping -= 1
self.write_message(message)
class WebSocketNativeCoroutineTest(WebSocketBaseTestCase):
def get_app(self):
return Application([("/native", NativeCoroutineOnMessageHandler)])
@gen_test
def test_native_coroutine(self):
ws = yield self.ws_connect("/native")
# Send both messages immediately, coroutine must process one at a time.
yield ws.write_message("hello1")
yield ws.write_message("hello2")
res = yield ws.read_message()
self.assertEqual(res, "hello1")
res = yield ws.read_message()
self.assertEqual(res, "hello2")
class CompressionTestMixin(object):
MESSAGE = "Hello world. Testing 123 123"
def get_app(self):
class LimitedHandler(TestWebSocketHandler):
@property
def max_message_size(self):
return 1024
def on_message(self, message):
self.write_message(str(len(message)))
return Application(
[
(
"/echo",
EchoHandler,
dict(compression_options=self.get_server_compression_options()),
),
(
"/limited",
LimitedHandler,
dict(compression_options=self.get_server_compression_options()),
),
]
)
def get_server_compression_options(self):
return None
def get_client_compression_options(self):
return None
@gen_test
def test_message_sizes(self):
ws = yield self.ws_connect(
"/echo", compression_options=self.get_client_compression_options()
)
# Send the same message three times so we can measure the
# effect of the context_takeover options.
for i in range(3):
ws.write_message(self.MESSAGE)
response = yield ws.read_message()
self.assertEqual(response, self.MESSAGE)
self.assertEqual(ws.protocol._message_bytes_out, len(self.MESSAGE) * 3)
self.assertEqual(ws.protocol._message_bytes_in, len(self.MESSAGE) * 3)
self.verify_wire_bytes(ws.protocol._wire_bytes_in, ws.protocol._wire_bytes_out)
@gen_test
def test_size_limit(self):
ws = yield self.ws_connect(
"/limited", compression_options=self.get_client_compression_options()
)
# Small messages pass through.
ws.write_message("a" * 128)
response = yield ws.read_message()
self.assertEqual(response, "128")
# This message is too big after decompression, but it compresses
# down to a size that will pass the initial checks.
ws.write_message("a" * 2048)
response = yield ws.read_message()
self.assertIsNone(response)
class UncompressedTestMixin(CompressionTestMixin):
"""Specialization of CompressionTestMixin when we expect no compression."""
def verify_wire_bytes(self, bytes_in, bytes_out):
# Bytes out includes the 4-byte mask key per message.
self.assertEqual(bytes_out, 3 * (len(self.MESSAGE) + 6))
self.assertEqual(bytes_in, 3 * (len(self.MESSAGE) + 2))
class NoCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
pass
# If only one side tries to compress, the extension is not negotiated.
class ServerOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
def get_server_compression_options(self):
return {}
class ClientOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
def get_client_compression_options(self):
return {}
class DefaultCompressionTest(CompressionTestMixin, WebSocketBaseTestCase):
def get_server_compression_options(self):
return {}
def get_client_compression_options(self):
return {}
def verify_wire_bytes(self, bytes_in, bytes_out):
self.assertLess(bytes_out, 3 * (len(self.MESSAGE) + 6))
self.assertLess(bytes_in, 3 * (len(self.MESSAGE) + 2))
# Bytes out includes the 4 bytes mask key per message.
self.assertEqual(bytes_out, bytes_in + 12)
class MaskFunctionMixin(object):
# Subclasses should define self.mask(mask, data)
def test_mask(self):
self.assertEqual(self.mask(b"abcd", b""), b"")
self.assertEqual(self.mask(b"abcd", b"b"), b"\x03")
self.assertEqual(self.mask(b"abcd", b"54321"), b"TVPVP")
self.assertEqual(self.mask(b"ZXCV", b"98765432"), b"c`t`olpd")
# Include test cases with \x00 bytes (to ensure that the C
# extension isn't depending on null-terminated strings) and
# bytes with the high bit set (to smoke out signedness issues).
self.assertEqual(
self.mask(b"\x00\x01\x02\x03", b"\xff\xfb\xfd\xfc\xfe\xfa"),
b"\xff\xfa\xff\xff\xfe\xfb",
)
self.assertEqual(
self.mask(b"\xff\xfb\xfd\xfc", b"\x00\x01\x02\x03\x04\x05"),
b"\xff\xfa\xff\xff\xfb\xfe",
)
class PythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
def mask(self, mask, data):
return _websocket_mask_python(mask, data)
@unittest.skipIf(speedups is None, "tornado.speedups module not present")
class CythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
def mask(self, mask, data):
return speedups.websocket_mask(mask, data)
class ServerPeriodicPingTest(WebSocketBaseTestCase):
def get_app(self):
class PingHandler(TestWebSocketHandler):
def on_pong(self, data):
self.write_message("got pong")
return Application([("/", PingHandler)], websocket_ping_interval=0.01)
@gen_test
def test_server_ping(self):
ws = yield self.ws_connect("/")
for i in range(3):
response = yield ws.read_message()
self.assertEqual(response, "got pong")
# TODO: test that the connection gets closed if ping responses stop.
class ClientPeriodicPingTest(WebSocketBaseTestCase):
def get_app(self):
class PingHandler(TestWebSocketHandler):
def on_ping(self, data):
self.write_message("got ping")
return Application([("/", PingHandler)])
@gen_test
def test_client_ping(self):
ws = yield self.ws_connect("/", ping_interval=0.01)
for i in range(3):
response = yield ws.read_message()
self.assertEqual(response, "got ping")
# TODO: test that the connection gets closed if ping responses stop.
class ManualPingTest(WebSocketBaseTestCase):
def get_app(self):
class PingHandler(TestWebSocketHandler):
def on_ping(self, data):
self.write_message(data, binary=isinstance(data, bytes))
return Application([("/", PingHandler)])
@gen_test
def test_manual_ping(self):
ws = yield self.ws_connect("/")
self.assertRaises(ValueError, ws.ping, "a" * 126)
ws.ping("hello")
resp = yield ws.read_message()
# on_ping always sees bytes.
self.assertEqual(resp, b"hello")
ws.ping(b"binary hello")
resp = yield ws.read_message()
self.assertEqual(resp, b"binary hello")
class MaxMessageSizeTest(WebSocketBaseTestCase):
def get_app(self):
return Application([("/", EchoHandler)], websocket_max_message_size=1024)
@gen_test
def test_large_message(self):
ws = yield self.ws_connect("/")
# Write a message that is allowed.
msg = "a" * 1024
ws.write_message(msg)
resp = yield ws.read_message()
self.assertEqual(resp, msg)
# Write a message that is too large.
ws.write_message(msg + "b")
resp = yield ws.read_message()
# A message of None means the other side closed the connection.
self.assertIs(resp, None)
self.assertEqual(ws.close_code, 1009)
self.assertEqual(ws.close_reason, "message too big")
# TODO: Needs tests of messages split over multiple
# continuation frames.

View File

@@ -0,0 +1,24 @@
import functools
import os
import socket
import unittest
from tornado.platform.auto import set_close_exec
skipIfNonWindows = unittest.skipIf(os.name != "nt", "non-windows platform")
@skipIfNonWindows
class WindowsTest(unittest.TestCase):
def test_set_close_exec(self):
# set_close_exec works with sockets.
s = socket.socket()
self.addCleanup(s.close)
set_close_exec(s.fileno())
# But it doesn't work with pipes.
r, w = os.pipe()
self.addCleanup(functools.partial(os.close, r))
self.addCleanup(functools.partial(os.close, w))
with self.assertRaises(WindowsError):
set_close_exec(r)

View File

@@ -0,0 +1,20 @@
from wsgiref.validate import validator
from tornado.testing import AsyncHTTPTestCase
from tornado.wsgi import WSGIContainer
class WSGIContainerTest(AsyncHTTPTestCase):
# TODO: Now that WSGIAdapter is gone, this is a pretty weak test.
def wsgi_app(self, environ, start_response):
status = "200 OK"
response_headers = [("Content-Type", "text/plain")]
start_response(status, response_headers)
return [b"Hello world!"]
def get_app(self):
return WSGIContainer(validator(self.wsgi_app))
def test_simple(self):
response = self.fetch("/")
self.assertEqual(response.body, b"Hello world!")

View File

@@ -0,0 +1,790 @@
"""Support classes for automated testing.
* `AsyncTestCase` and `AsyncHTTPTestCase`: Subclasses of unittest.TestCase
with additional support for testing asynchronous (`.IOLoop`-based) code.
* `ExpectLog`: Make test logs less spammy.
* `main()`: A simple test runner (wrapper around unittest.main()) with support
for the tornado.autoreload module to rerun the tests when code changes.
"""
import asyncio
from collections.abc import Generator
import functools
import inspect
import logging
import os
import re
import signal
import socket
import sys
import unittest
from tornado import gen
from tornado.httpclient import AsyncHTTPClient, HTTPResponse
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop, TimeoutError
from tornado import netutil
from tornado.platform.asyncio import AsyncIOMainLoop
from tornado.process import Subprocess
from tornado.log import app_log
from tornado.util import raise_exc_info, basestring_type
from tornado.web import Application
import typing
from typing import Tuple, Any, Callable, Type, Dict, Union, Optional
from types import TracebackType
if typing.TYPE_CHECKING:
# Coroutine wasn't added to typing until 3.5.3, so only import it
# when mypy is running and use forward references.
from typing import Coroutine # noqa: F401
_ExcInfoTuple = Tuple[
Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]
]
_NON_OWNED_IOLOOPS = AsyncIOMainLoop
def bind_unused_port(reuse_port: bool = False) -> Tuple[socket.socket, int]:
"""Binds a server socket to an available port on localhost.
Returns a tuple (socket, port).
.. versionchanged:: 4.4
Always binds to ``127.0.0.1`` without resolving the name
``localhost``.
"""
sock = netutil.bind_sockets(
0, "127.0.0.1", family=socket.AF_INET, reuse_port=reuse_port
)[0]
port = sock.getsockname()[1]
return sock, port
def get_async_test_timeout() -> float:
"""Get the global timeout setting for async tests.
Returns a float, the timeout in seconds.
.. versionadded:: 3.1
"""
env = os.environ.get("ASYNC_TEST_TIMEOUT")
if env is not None:
try:
return float(env)
except ValueError:
pass
return 5
class _TestMethodWrapper(object):
"""Wraps a test method to raise an error if it returns a value.
This is mainly used to detect undecorated generators (if a test
method yields it must use a decorator to consume the generator),
but will also detect other kinds of return values (these are not
necessarily errors, but we alert anyway since there is no good
reason to return a value from a test).
"""
def __init__(self, orig_method: Callable) -> None:
self.orig_method = orig_method
def __call__(self, *args: Any, **kwargs: Any) -> None:
result = self.orig_method(*args, **kwargs)
if isinstance(result, Generator) or inspect.iscoroutine(result):
raise TypeError(
"Generator and coroutine test methods should be"
" decorated with tornado.testing.gen_test"
)
elif result is not None:
raise ValueError("Return value from test method ignored: %r" % result)
def __getattr__(self, name: str) -> Any:
"""Proxy all unknown attributes to the original method.
This is important for some of the decorators in the `unittest`
module, such as `unittest.skipIf`.
"""
return getattr(self.orig_method, name)
class AsyncTestCase(unittest.TestCase):
"""`~unittest.TestCase` subclass for testing `.IOLoop`-based
asynchronous code.
The unittest framework is synchronous, so the test must be
complete by the time the test method returns. This means that
asynchronous code cannot be used in quite the same way as usual
and must be adapted to fit. To write your tests with coroutines,
decorate your test methods with `tornado.testing.gen_test` instead
of `tornado.gen.coroutine`.
This class also provides the (deprecated) `stop()` and `wait()`
methods for a more manual style of testing. The test method itself
must call ``self.wait()``, and asynchronous callbacks should call
``self.stop()`` to signal completion.
By default, a new `.IOLoop` is constructed for each test and is available
as ``self.io_loop``. If the code being tested requires a
global `.IOLoop`, subclasses should override `get_new_ioloop` to return it.
The `.IOLoop`'s ``start`` and ``stop`` methods should not be
called directly. Instead, use `self.stop <stop>` and `self.wait
<wait>`. Arguments passed to ``self.stop`` are returned from
``self.wait``. It is possible to have multiple ``wait``/``stop``
cycles in the same test.
Example::
# This test uses coroutine style.
class MyTestCase(AsyncTestCase):
@tornado.testing.gen_test
def test_http_fetch(self):
client = AsyncHTTPClient()
response = yield client.fetch("http://www.tornadoweb.org")
# Test contents of response
self.assertIn("FriendFeed", response.body)
# This test uses argument passing between self.stop and self.wait.
class MyTestCase2(AsyncTestCase):
def test_http_fetch(self):
client = AsyncHTTPClient()
client.fetch("http://www.tornadoweb.org/", self.stop)
response = self.wait()
# Test contents of response
self.assertIn("FriendFeed", response.body)
"""
def __init__(self, methodName: str = "runTest") -> None:
super(AsyncTestCase, self).__init__(methodName)
self.__stopped = False
self.__running = False
self.__failure = None # type: Optional[_ExcInfoTuple]
self.__stop_args = None # type: Any
self.__timeout = None # type: Optional[object]
# It's easy to forget the @gen_test decorator, but if you do
# the test will silently be ignored because nothing will consume
# the generator. Replace the test method with a wrapper that will
# make sure it's not an undecorated generator.
setattr(self, methodName, _TestMethodWrapper(getattr(self, methodName)))
# Not used in this class itself, but used by @gen_test
self._test_generator = None # type: Optional[Union[Generator, Coroutine]]
def setUp(self) -> None:
super(AsyncTestCase, self).setUp()
self.io_loop = self.get_new_ioloop()
self.io_loop.make_current()
def tearDown(self) -> None:
# Native coroutines tend to produce warnings if they're not
# allowed to run to completion. It's difficult to ensure that
# this always happens in tests, so cancel any tasks that are
# still pending by the time we get here.
asyncio_loop = self.io_loop.asyncio_loop # type: ignore
if hasattr(asyncio, "all_tasks"): # py37
tasks = asyncio.all_tasks(asyncio_loop) # type: ignore
else:
tasks = asyncio.Task.all_tasks(asyncio_loop)
# Tasks that are done may still appear here and may contain
# non-cancellation exceptions, so filter them out.
tasks = [t for t in tasks if not t.done()]
for t in tasks:
t.cancel()
# Allow the tasks to run and finalize themselves (which means
# raising a CancelledError inside the coroutine). This may
# just transform the "task was destroyed but it is pending"
# warning into a "uncaught CancelledError" warning, but
# catching CancelledErrors in coroutines that may leak is
# simpler than ensuring that no coroutines leak.
if tasks:
done, pending = self.io_loop.run_sync(lambda: asyncio.wait(tasks))
assert not pending
# If any task failed with anything but a CancelledError, raise it.
for f in done:
try:
f.result()
except asyncio.CancelledError:
pass
# Clean up Subprocess, so it can be used again with a new ioloop.
Subprocess.uninitialize()
self.io_loop.clear_current()
if not isinstance(self.io_loop, _NON_OWNED_IOLOOPS):
# Try to clean up any file descriptors left open in the ioloop.
# This avoids leaks, especially when tests are run repeatedly
# in the same process with autoreload (because curl does not
# set FD_CLOEXEC on its file descriptors)
self.io_loop.close(all_fds=True)
super(AsyncTestCase, self).tearDown()
# In case an exception escaped or the StackContext caught an exception
# when there wasn't a wait() to re-raise it, do so here.
# This is our last chance to raise an exception in a way that the
# unittest machinery understands.
self.__rethrow()
def get_new_ioloop(self) -> IOLoop:
"""Returns the `.IOLoop` to use for this test.
By default, a new `.IOLoop` is created for each test.
Subclasses may override this method to return
`.IOLoop.current()` if it is not appropriate to use a new
`.IOLoop` in each tests (for example, if there are global
singletons using the default `.IOLoop`) or if a per-test event
loop is being provided by another system (such as
``pytest-asyncio``).
"""
return IOLoop()
def _handle_exception(
self, typ: Type[Exception], value: Exception, tb: TracebackType
) -> bool:
if self.__failure is None:
self.__failure = (typ, value, tb)
else:
app_log.error(
"multiple unhandled exceptions in test", exc_info=(typ, value, tb)
)
self.stop()
return True
def __rethrow(self) -> None:
if self.__failure is not None:
failure = self.__failure
self.__failure = None
raise_exc_info(failure)
def run(self, result: unittest.TestResult = None) -> unittest.TestCase:
ret = super(AsyncTestCase, self).run(result)
# As a last resort, if an exception escaped super.run() and wasn't
# re-raised in tearDown, raise it here. This will cause the
# unittest run to fail messily, but that's better than silently
# ignoring an error.
self.__rethrow()
return ret
def stop(self, _arg: Any = None, **kwargs: Any) -> None:
"""Stops the `.IOLoop`, causing one pending (or future) call to `wait()`
to return.
Keyword arguments or a single positional argument passed to `stop()` are
saved and will be returned by `wait()`.
.. deprecated:: 5.1
`stop` and `wait` are deprecated; use ``@gen_test`` instead.
"""
assert _arg is None or not kwargs
self.__stop_args = kwargs or _arg
if self.__running:
self.io_loop.stop()
self.__running = False
self.__stopped = True
def wait(
self, condition: Callable[..., bool] = None, timeout: float = None
) -> None:
"""Runs the `.IOLoop` until stop is called or timeout has passed.
In the event of a timeout, an exception will be thrown. The
default timeout is 5 seconds; it may be overridden with a
``timeout`` keyword argument or globally with the
``ASYNC_TEST_TIMEOUT`` environment variable.
If ``condition`` is not ``None``, the `.IOLoop` will be restarted
after `stop()` until ``condition()`` returns ``True``.
.. versionchanged:: 3.1
Added the ``ASYNC_TEST_TIMEOUT`` environment variable.
.. deprecated:: 5.1
`stop` and `wait` are deprecated; use ``@gen_test`` instead.
"""
if timeout is None:
timeout = get_async_test_timeout()
if not self.__stopped:
if timeout:
def timeout_func() -> None:
try:
raise self.failureException(
"Async operation timed out after %s seconds" % timeout
)
except Exception:
self.__failure = sys.exc_info()
self.stop()
self.__timeout = self.io_loop.add_timeout(
self.io_loop.time() + timeout, timeout_func
)
while True:
self.__running = True
self.io_loop.start()
if self.__failure is not None or condition is None or condition():
break
if self.__timeout is not None:
self.io_loop.remove_timeout(self.__timeout)
self.__timeout = None
assert self.__stopped
self.__stopped = False
self.__rethrow()
result = self.__stop_args
self.__stop_args = None
return result
class AsyncHTTPTestCase(AsyncTestCase):
"""A test case that starts up an HTTP server.
Subclasses must override `get_app()`, which returns the
`tornado.web.Application` (or other `.HTTPServer` callback) to be tested.
Tests will typically use the provided ``self.http_client`` to fetch
URLs from this server.
Example, assuming the "Hello, world" example from the user guide is in
``hello.py``::
import hello
class TestHelloApp(AsyncHTTPTestCase):
def get_app(self):
return hello.make_app()
def test_homepage(self):
response = self.fetch('/')
self.assertEqual(response.code, 200)
self.assertEqual(response.body, 'Hello, world')
That call to ``self.fetch()`` is equivalent to ::
self.http_client.fetch(self.get_url('/'), self.stop)
response = self.wait()
which illustrates how AsyncTestCase can turn an asynchronous operation,
like ``http_client.fetch()``, into a synchronous operation. If you need
to do other asynchronous operations in tests, you'll probably need to use
``stop()`` and ``wait()`` yourself.
"""
def setUp(self) -> None:
super(AsyncHTTPTestCase, self).setUp()
sock, port = bind_unused_port()
self.__port = port
self.http_client = self.get_http_client()
self._app = self.get_app()
self.http_server = self.get_http_server()
self.http_server.add_sockets([sock])
def get_http_client(self) -> AsyncHTTPClient:
return AsyncHTTPClient()
def get_http_server(self) -> HTTPServer:
return HTTPServer(self._app, **self.get_httpserver_options())
def get_app(self) -> Application:
"""Should be overridden by subclasses to return a
`tornado.web.Application` or other `.HTTPServer` callback.
"""
raise NotImplementedError()
def fetch(
self, path: str, raise_error: bool = False, **kwargs: Any
) -> HTTPResponse:
"""Convenience method to synchronously fetch a URL.
The given path will be appended to the local server's host and
port. Any additional keyword arguments will be passed directly to
`.AsyncHTTPClient.fetch` (and so could be used to pass
``method="POST"``, ``body="..."``, etc).
If the path begins with http:// or https://, it will be treated as a
full URL and will be fetched as-is.
If ``raise_error`` is ``True``, a `tornado.httpclient.HTTPError` will
be raised if the response code is not 200. This is the same behavior
as the ``raise_error`` argument to `.AsyncHTTPClient.fetch`, but
the default is ``False`` here (it's ``True`` in `.AsyncHTTPClient`)
because tests often need to deal with non-200 response codes.
.. versionchanged:: 5.0
Added support for absolute URLs.
.. versionchanged:: 5.1
Added the ``raise_error`` argument.
.. deprecated:: 5.1
This method currently turns any exception into an
`.HTTPResponse` with status code 599. In Tornado 6.0,
errors other than `tornado.httpclient.HTTPError` will be
passed through, and ``raise_error=False`` will only
suppress errors that would be raised due to non-200
response codes.
"""
if path.lower().startswith(("http://", "https://")):
url = path
else:
url = self.get_url(path)
return self.io_loop.run_sync(
lambda: self.http_client.fetch(url, raise_error=raise_error, **kwargs),
timeout=get_async_test_timeout(),
)
def get_httpserver_options(self) -> Dict[str, Any]:
"""May be overridden by subclasses to return additional
keyword arguments for the server.
"""
return {}
def get_http_port(self) -> int:
"""Returns the port used by the server.
A new port is chosen for each test.
"""
return self.__port
def get_protocol(self) -> str:
return "http"
def get_url(self, path: str) -> str:
"""Returns an absolute url for the given path on the test server."""
return "%s://127.0.0.1:%s%s" % (self.get_protocol(), self.get_http_port(), path)
def tearDown(self) -> None:
self.http_server.stop()
self.io_loop.run_sync(
self.http_server.close_all_connections, timeout=get_async_test_timeout()
)
self.http_client.close()
del self.http_server
del self._app
super(AsyncHTTPTestCase, self).tearDown()
class AsyncHTTPSTestCase(AsyncHTTPTestCase):
"""A test case that starts an HTTPS server.
Interface is generally the same as `AsyncHTTPTestCase`.
"""
def get_http_client(self) -> AsyncHTTPClient:
return AsyncHTTPClient(force_instance=True, defaults=dict(validate_cert=False))
def get_httpserver_options(self) -> Dict[str, Any]:
return dict(ssl_options=self.get_ssl_options())
def get_ssl_options(self) -> Dict[str, Any]:
"""May be overridden by subclasses to select SSL options.
By default includes a self-signed testing certificate.
"""
return AsyncHTTPSTestCase.default_ssl_options()
@staticmethod
def default_ssl_options() -> Dict[str, Any]:
# Testing keys were generated with:
# openssl req -new -keyout tornado/test/test.key \
# -out tornado/test/test.crt -nodes -days 3650 -x509
module_dir = os.path.dirname(__file__)
return dict(
certfile=os.path.join(module_dir, "test", "test.crt"),
keyfile=os.path.join(module_dir, "test", "test.key"),
)
def get_protocol(self) -> str:
return "https"
@typing.overload
def gen_test(
*, timeout: float = None
) -> Callable[[Callable[..., Union[Generator, "Coroutine"]]], Callable[..., None]]:
pass
@typing.overload # noqa: F811
def gen_test(func: Callable[..., Union[Generator, "Coroutine"]]) -> Callable[..., None]:
pass
def gen_test( # noqa: F811
func: Callable[..., Union[Generator, "Coroutine"]] = None, timeout: float = None
) -> Union[
Callable[..., None],
Callable[[Callable[..., Union[Generator, "Coroutine"]]], Callable[..., None]],
]:
"""Testing equivalent of ``@gen.coroutine``, to be applied to test methods.
``@gen.coroutine`` cannot be used on tests because the `.IOLoop` is not
already running. ``@gen_test`` should be applied to test methods
on subclasses of `AsyncTestCase`.
Example::
class MyTest(AsyncHTTPTestCase):
@gen_test
def test_something(self):
response = yield self.http_client.fetch(self.get_url('/'))
By default, ``@gen_test`` times out after 5 seconds. The timeout may be
overridden globally with the ``ASYNC_TEST_TIMEOUT`` environment variable,
or for each test with the ``timeout`` keyword argument::
class MyTest(AsyncHTTPTestCase):
@gen_test(timeout=10)
def test_something_slow(self):
response = yield self.http_client.fetch(self.get_url('/'))
Note that ``@gen_test`` is incompatible with `AsyncTestCase.stop`,
`AsyncTestCase.wait`, and `AsyncHTTPTestCase.fetch`. Use ``yield
self.http_client.fetch(self.get_url())`` as shown above instead.
.. versionadded:: 3.1
The ``timeout`` argument and ``ASYNC_TEST_TIMEOUT`` environment
variable.
.. versionchanged:: 4.0
The wrapper now passes along ``*args, **kwargs`` so it can be used
on functions with arguments.
"""
if timeout is None:
timeout = get_async_test_timeout()
def wrap(f: Callable[..., Union[Generator, "Coroutine"]]) -> Callable[..., None]:
# Stack up several decorators to allow us to access the generator
# object itself. In the innermost wrapper, we capture the generator
# and save it in an attribute of self. Next, we run the wrapped
# function through @gen.coroutine. Finally, the coroutine is
# wrapped again to make it synchronous with run_sync.
#
# This is a good case study arguing for either some sort of
# extensibility in the gen decorators or cancellation support.
@functools.wraps(f)
def pre_coroutine(self, *args, **kwargs):
# type: (AsyncTestCase, *Any, **Any) -> Union[Generator, Coroutine]
# Type comments used to avoid pypy3 bug.
result = f(self, *args, **kwargs)
if isinstance(result, Generator) or inspect.iscoroutine(result):
self._test_generator = result
else:
self._test_generator = None
return result
if inspect.iscoroutinefunction(f):
coro = pre_coroutine
else:
coro = gen.coroutine(pre_coroutine)
@functools.wraps(coro)
def post_coroutine(self, *args, **kwargs):
# type: (AsyncTestCase, *Any, **Any) -> None
try:
return self.io_loop.run_sync(
functools.partial(coro, self, *args, **kwargs), timeout=timeout
)
except TimeoutError as e:
# run_sync raises an error with an unhelpful traceback.
# If the underlying generator is still running, we can throw the
# exception back into it so the stack trace is replaced by the
# point where the test is stopped. The only reason the generator
# would not be running would be if it were cancelled, which means
# a native coroutine, so we can rely on the cr_running attribute.
if self._test_generator is not None and getattr(
self._test_generator, "cr_running", True
):
self._test_generator.throw(type(e), e)
# In case the test contains an overly broad except
# clause, we may get back here.
# Coroutine was stopped or didn't raise a useful stack trace,
# so re-raise the original exception which is better than nothing.
raise
return post_coroutine
if func is not None:
# Used like:
# @gen_test
# def f(self):
# pass
return wrap(func)
else:
# Used like @gen_test(timeout=10)
return wrap
# Without this attribute, nosetests will try to run gen_test as a test
# anywhere it is imported.
gen_test.__test__ = False # type: ignore
class ExpectLog(logging.Filter):
"""Context manager to capture and suppress expected log output.
Useful to make tests of error conditions less noisy, while still
leaving unexpected log entries visible. *Not thread safe.*
The attribute ``logged_stack`` is set to ``True`` if any exception
stack trace was logged.
Usage::
with ExpectLog('tornado.application', "Uncaught exception"):
error_response = self.fetch("/some_page")
.. versionchanged:: 4.3
Added the ``logged_stack`` attribute.
"""
def __init__(
self,
logger: Union[logging.Logger, basestring_type],
regex: str,
required: bool = True,
) -> None:
"""Constructs an ExpectLog context manager.
:param logger: Logger object (or name of logger) to watch. Pass
an empty string to watch the root logger.
:param regex: Regular expression to match. Any log entries on
the specified logger that match this regex will be suppressed.
:param required: If true, an exception will be raised if the end of
the ``with`` statement is reached without matching any log entries.
"""
if isinstance(logger, basestring_type):
logger = logging.getLogger(logger)
self.logger = logger
self.regex = re.compile(regex)
self.required = required
self.matched = False
self.logged_stack = False
def filter(self, record: logging.LogRecord) -> bool:
if record.exc_info:
self.logged_stack = True
message = record.getMessage()
if self.regex.match(message):
self.matched = True
return False
return True
def __enter__(self) -> "ExpectLog":
self.logger.addFilter(self)
return self
def __exit__(
self,
typ: "Optional[Type[BaseException]]",
value: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:
self.logger.removeFilter(self)
if not typ and self.required and not self.matched:
raise Exception("did not get expected log message")
def main(**kwargs: Any) -> None:
"""A simple test runner.
This test runner is essentially equivalent to `unittest.main` from
the standard library, but adds support for Tornado-style option
parsing and log formatting. It is *not* necessary to use this
`main` function to run tests using `AsyncTestCase`; these tests
are self-contained and can run with any test runner.
The easiest way to run a test is via the command line::
python -m tornado.testing tornado.test.web_test
See the standard library ``unittest`` module for ways in which
tests can be specified.
Projects with many tests may wish to define a test script like
``tornado/test/runtests.py``. This script should define a method
``all()`` which returns a test suite and then call
`tornado.testing.main()`. Note that even when a test script is
used, the ``all()`` test suite may be overridden by naming a
single test on the command line::
# Runs all tests
python -m tornado.test.runtests
# Runs one test
python -m tornado.test.runtests tornado.test.web_test
Additional keyword arguments passed through to ``unittest.main()``.
For example, use ``tornado.testing.main(verbosity=2)``
to show many test details as they are run.
See http://docs.python.org/library/unittest.html#unittest.main
for full argument list.
.. versionchanged:: 5.0
This function produces no output of its own; only that produced
by the `unittest` module (previously it would add a PASS or FAIL
log message).
"""
from tornado.options import define, options, parse_command_line
define(
"exception_on_interrupt",
type=bool,
default=True,
help=(
"If true (default), ctrl-c raises a KeyboardInterrupt "
"exception. This prints a stack trace but cannot interrupt "
"certain operations. If false, the process is more reliably "
"killed, but does not print a stack trace."
),
)
# support the same options as unittest's command-line interface
define("verbose", type=bool)
define("quiet", type=bool)
define("failfast", type=bool)
define("catch", type=bool)
define("buffer", type=bool)
argv = [sys.argv[0]] + parse_command_line(sys.argv)
if not options.exception_on_interrupt:
signal.signal(signal.SIGINT, signal.SIG_DFL)
if options.verbose is not None:
kwargs["verbosity"] = 2
if options.quiet is not None:
kwargs["verbosity"] = 0
if options.failfast is not None:
kwargs["failfast"] = True
if options.catch is not None:
kwargs["catchbreak"] = True
if options.buffer is not None:
kwargs["buffer"] = True
if __name__ == "__main__" and len(argv) == 1:
print("No tests specified", file=sys.stderr)
sys.exit(1)
# In order to be able to run tests by their fully-qualified name
# on the command line without importing all tests here,
# module must be set to None. Python 3.2's unittest.main ignores
# defaultTest if no module is given (it tries to do its own
# test discovery, which is incompatible with auto2to3), so don't
# set module if we're not asking for a specific test.
if len(argv) > 1:
unittest.main(module=None, argv=argv, **kwargs) # type: ignore
else:
unittest.main(defaultTest="all", argv=argv, **kwargs)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,472 @@
"""Miscellaneous utility functions and classes.
This module is used internally by Tornado. It is not necessarily expected
that the functions and classes defined here will be useful to other
applications, but they are documented here in case they are.
The one public-facing part of this module is the `Configurable` class
and its `~Configurable.configure` method, which becomes a part of the
interface of its subclasses, including `.AsyncHTTPClient`, `.IOLoop`,
and `.Resolver`.
"""
import array
import atexit
from inspect import getfullargspec
import os
import re
import typing
import zlib
from typing import (
Any,
Optional,
Dict,
Mapping,
List,
Tuple,
Match,
Callable,
Type,
Sequence,
)
if typing.TYPE_CHECKING:
# Additional imports only used in type comments.
# This lets us make these imports lazy.
import datetime # noqa: F401
from types import TracebackType # noqa: F401
from typing import Union # noqa: F401
import unittest # noqa: F401
# Aliases for types that are spelled differently in different Python
# versions. bytes_type is deprecated and no longer used in Tornado
# itself but is left in case anyone outside Tornado is using it.
bytes_type = bytes
unicode_type = str
basestring_type = str
try:
from sys import is_finalizing
except ImportError:
# Emulate it
def _get_emulated_is_finalizing() -> Callable[[], bool]:
L = [] # type: List[None]
atexit.register(lambda: L.append(None))
def is_finalizing() -> bool:
# Not referencing any globals here
return L != []
return is_finalizing
is_finalizing = _get_emulated_is_finalizing()
class TimeoutError(Exception):
"""Exception raised by `.with_timeout` and `.IOLoop.run_sync`.
.. versionchanged:: 5.0:
Unified ``tornado.gen.TimeoutError`` and
``tornado.ioloop.TimeoutError`` as ``tornado.util.TimeoutError``.
Both former names remain as aliases.
"""
class ObjectDict(Dict[str, Any]):
"""Makes a dictionary behave like an object, with attribute-style access.
"""
def __getattr__(self, name: str) -> Any:
try:
return self[name]
except KeyError:
raise AttributeError(name)
def __setattr__(self, name: str, value: Any) -> None:
self[name] = value
class GzipDecompressor(object):
"""Streaming gzip decompressor.
The interface is like that of `zlib.decompressobj` (without some of the
optional arguments, but it understands gzip headers and checksums.
"""
def __init__(self) -> None:
# Magic parameter makes zlib module understand gzip header
# http://stackoverflow.com/questions/1838699/how-can-i-decompress-a-gzip-stream-with-zlib
# This works on cpython and pypy, but not jython.
self.decompressobj = zlib.decompressobj(16 + zlib.MAX_WBITS)
def decompress(self, value: bytes, max_length: int = 0) -> bytes:
"""Decompress a chunk, returning newly-available data.
Some data may be buffered for later processing; `flush` must
be called when there is no more input data to ensure that
all data was processed.
If ``max_length`` is given, some input data may be left over
in ``unconsumed_tail``; you must retrieve this value and pass
it back to a future call to `decompress` if it is not empty.
"""
return self.decompressobj.decompress(value, max_length)
@property
def unconsumed_tail(self) -> bytes:
"""Returns the unconsumed portion left over
"""
return self.decompressobj.unconsumed_tail
def flush(self) -> bytes:
"""Return any remaining buffered data not yet returned by decompress.
Also checks for errors such as truncated input.
No other methods may be called on this object after `flush`.
"""
return self.decompressobj.flush()
def import_object(name: str) -> Any:
"""Imports an object by name.
``import_object('x')`` is equivalent to ``import x``.
``import_object('x.y.z')`` is equivalent to ``from x.y import z``.
>>> import tornado.escape
>>> import_object('tornado.escape') is tornado.escape
True
>>> import_object('tornado.escape.utf8') is tornado.escape.utf8
True
>>> import_object('tornado') is tornado
True
>>> import_object('tornado.missing_module')
Traceback (most recent call last):
...
ImportError: No module named missing_module
"""
if name.count(".") == 0:
return __import__(name)
parts = name.split(".")
obj = __import__(".".join(parts[:-1]), fromlist=[parts[-1]])
try:
return getattr(obj, parts[-1])
except AttributeError:
raise ImportError("No module named %s" % parts[-1])
def exec_in(code: Any, glob: Dict[str, Any], loc: Mapping[str, Any] = None) -> None:
if isinstance(code, str):
# exec(string) inherits the caller's future imports; compile
# the string first to prevent that.
code = compile(code, "<string>", "exec", dont_inherit=True)
exec(code, glob, loc)
def raise_exc_info(
exc_info, # type: Tuple[Optional[type], Optional[BaseException], Optional[TracebackType]]
):
# type: (...) -> typing.NoReturn
#
# This function's type annotation must use comments instead of
# real annotations because typing.NoReturn does not exist in
# python 3.5's typing module. The formatting is funky because this
# is apparently what flake8 wants.
try:
if exc_info[1] is not None:
raise exc_info[1].with_traceback(exc_info[2])
else:
raise TypeError("raise_exc_info called with no exception")
finally:
# Clear the traceback reference from our stack frame to
# minimize circular references that slow down GC.
exc_info = (None, None, None)
def errno_from_exception(e: BaseException) -> Optional[int]:
"""Provides the errno from an Exception object.
There are cases that the errno attribute was not set so we pull
the errno out of the args but if someone instantiates an Exception
without any args you will get a tuple error. So this function
abstracts all that behavior to give you a safe way to get the
errno.
"""
if hasattr(e, "errno"):
return e.errno # type: ignore
elif e.args:
return e.args[0]
else:
return None
_alphanum = frozenset("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
def _re_unescape_replacement(match: Match[str]) -> str:
group = match.group(1)
if group[0] in _alphanum:
raise ValueError("cannot unescape '\\\\%s'" % group[0])
return group
_re_unescape_pattern = re.compile(r"\\(.)", re.DOTALL)
def re_unescape(s: str) -> str:
r"""Unescape a string escaped by `re.escape`.
May raise ``ValueError`` for regular expressions which could not
have been produced by `re.escape` (for example, strings containing
``\d`` cannot be unescaped).
.. versionadded:: 4.4
"""
return _re_unescape_pattern.sub(_re_unescape_replacement, s)
class Configurable(object):
"""Base class for configurable interfaces.
A configurable interface is an (abstract) class whose constructor
acts as a factory function for one of its implementation subclasses.
The implementation subclass as well as optional keyword arguments to
its initializer can be set globally at runtime with `configure`.
By using the constructor as the factory method, the interface
looks like a normal class, `isinstance` works as usual, etc. This
pattern is most useful when the choice of implementation is likely
to be a global decision (e.g. when `~select.epoll` is available,
always use it instead of `~select.select`), or when a
previously-monolithic class has been split into specialized
subclasses.
Configurable subclasses must define the class methods
`configurable_base` and `configurable_default`, and use the instance
method `initialize` instead of ``__init__``.
.. versionchanged:: 5.0
It is now possible for configuration to be specified at
multiple levels of a class hierarchy.
"""
# Type annotations on this class are mostly done with comments
# because they need to refer to Configurable, which isn't defined
# until after the class definition block. These can use regular
# annotations when our minimum python version is 3.7.
#
# There may be a clever way to use generics here to get more
# precise types (i.e. for a particular Configurable subclass T,
# all the types are subclasses of T, not just Configurable).
__impl_class = None # type: Optional[Type[Configurable]]
__impl_kwargs = None # type: Dict[str, Any]
def __new__(cls, *args: Any, **kwargs: Any) -> Any:
base = cls.configurable_base()
init_kwargs = {} # type: Dict[str, Any]
if cls is base:
impl = cls.configured_class()
if base.__impl_kwargs:
init_kwargs.update(base.__impl_kwargs)
else:
impl = cls
init_kwargs.update(kwargs)
if impl.configurable_base() is not base:
# The impl class is itself configurable, so recurse.
return impl(*args, **init_kwargs)
instance = super(Configurable, cls).__new__(impl)
# initialize vs __init__ chosen for compatibility with AsyncHTTPClient
# singleton magic. If we get rid of that we can switch to __init__
# here too.
instance.initialize(*args, **init_kwargs)
return instance
@classmethod
def configurable_base(cls):
# type: () -> Type[Configurable]
"""Returns the base class of a configurable hierarchy.
This will normally return the class in which it is defined.
(which is *not* necessarily the same as the ``cls`` classmethod
parameter).
"""
raise NotImplementedError()
@classmethod
def configurable_default(cls):
# type: () -> Type[Configurable]
"""Returns the implementation class to be used if none is configured."""
raise NotImplementedError()
def _initialize(self) -> None:
pass
initialize = _initialize # type: Callable[..., None]
"""Initialize a `Configurable` subclass instance.
Configurable classes should use `initialize` instead of ``__init__``.
.. versionchanged:: 4.2
Now accepts positional arguments in addition to keyword arguments.
"""
@classmethod
def configure(cls, impl, **kwargs):
# type: (Union[None, str, Type[Configurable]], Any) -> None
"""Sets the class to use when the base class is instantiated.
Keyword arguments will be saved and added to the arguments passed
to the constructor. This can be used to set global defaults for
some parameters.
"""
base = cls.configurable_base()
if isinstance(impl, str):
impl = typing.cast(Type[Configurable], import_object(impl))
if impl is not None and not issubclass(impl, cls):
raise ValueError("Invalid subclass of %s" % cls)
base.__impl_class = impl
base.__impl_kwargs = kwargs
@classmethod
def configured_class(cls):
# type: () -> Type[Configurable]
"""Returns the currently configured class."""
base = cls.configurable_base()
# Manually mangle the private name to see whether this base
# has been configured (and not another base higher in the
# hierarchy).
if base.__dict__.get("_Configurable__impl_class") is None:
base.__impl_class = cls.configurable_default()
if base.__impl_class is not None:
return base.__impl_class
else:
# Should be impossible, but mypy wants an explicit check.
raise ValueError("configured class not found")
@classmethod
def _save_configuration(cls):
# type: () -> Tuple[Optional[Type[Configurable]], Dict[str, Any]]
base = cls.configurable_base()
return (base.__impl_class, base.__impl_kwargs)
@classmethod
def _restore_configuration(cls, saved):
# type: (Tuple[Optional[Type[Configurable]], Dict[str, Any]]) -> None
base = cls.configurable_base()
base.__impl_class = saved[0]
base.__impl_kwargs = saved[1]
class ArgReplacer(object):
"""Replaces one value in an ``args, kwargs`` pair.
Inspects the function signature to find an argument by name
whether it is passed by position or keyword. For use in decorators
and similar wrappers.
"""
def __init__(self, func: Callable, name: str) -> None:
self.name = name
try:
self.arg_pos = self._getargnames(func).index(name) # type: Optional[int]
except ValueError:
# Not a positional parameter
self.arg_pos = None
def _getargnames(self, func: Callable) -> List[str]:
try:
return getfullargspec(func).args
except TypeError:
if hasattr(func, "func_code"):
# Cython-generated code has all the attributes needed
# by inspect.getfullargspec, but the inspect module only
# works with ordinary functions. Inline the portion of
# getfullargspec that we need here. Note that for static
# functions the @cython.binding(True) decorator must
# be used (for methods it works out of the box).
code = func.func_code # type: ignore
return code.co_varnames[: code.co_argcount]
raise
def get_old_value(
self, args: Sequence[Any], kwargs: Dict[str, Any], default: Any = None
) -> Any:
"""Returns the old value of the named argument without replacing it.
Returns ``default`` if the argument is not present.
"""
if self.arg_pos is not None and len(args) > self.arg_pos:
return args[self.arg_pos]
else:
return kwargs.get(self.name, default)
def replace(
self, new_value: Any, args: Sequence[Any], kwargs: Dict[str, Any]
) -> Tuple[Any, Sequence[Any], Dict[str, Any]]:
"""Replace the named argument in ``args, kwargs`` with ``new_value``.
Returns ``(old_value, args, kwargs)``. The returned ``args`` and
``kwargs`` objects may not be the same as the input objects, or
the input objects may be mutated.
If the named argument was not found, ``new_value`` will be added
to ``kwargs`` and None will be returned as ``old_value``.
"""
if self.arg_pos is not None and len(args) > self.arg_pos:
# The arg to replace is passed positionally
old_value = args[self.arg_pos]
args = list(args) # *args is normally a tuple
args[self.arg_pos] = new_value
else:
# The arg to replace is either omitted or passed by keyword.
old_value = kwargs.get(self.name)
kwargs[self.name] = new_value
return old_value, args, kwargs
def timedelta_to_seconds(td):
# type: (datetime.timedelta) -> float
"""Equivalent to ``td.total_seconds()`` (introduced in Python 2.7)."""
return td.total_seconds()
def _websocket_mask_python(mask: bytes, data: bytes) -> bytes:
"""Websocket masking function.
`mask` is a `bytes` object of length 4; `data` is a `bytes` object of any length.
Returns a `bytes` object of the same length as `data` with the mask applied
as specified in section 5.3 of RFC 6455.
This pure-python implementation may be replaced by an optimized version when available.
"""
mask_arr = array.array("B", mask)
unmasked_arr = array.array("B", data)
for i in range(len(data)):
unmasked_arr[i] = unmasked_arr[i] ^ mask_arr[i % 4]
return unmasked_arr.tobytes()
if os.environ.get("TORNADO_NO_EXTENSION") or os.environ.get("TORNADO_EXTENSION") == "0":
# These environment variables exist to make it easier to do performance
# comparisons; they are not guaranteed to remain supported in the future.
_websocket_mask = _websocket_mask_python
else:
try:
from tornado.speedups import websocket_mask as _websocket_mask
except ImportError:
if os.environ.get("TORNADO_EXTENSION") == "1":
raise
_websocket_mask = _websocket_mask_python
def doctests():
# type: () -> unittest.TestSuite
import doctest
return doctest.DocTestSuite()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,199 @@
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""WSGI support for the Tornado web framework.
WSGI is the Python standard for web servers, and allows for interoperability
between Tornado and other Python web frameworks and servers.
This module provides WSGI support via the `WSGIContainer` class, which
makes it possible to run applications using other WSGI frameworks on
the Tornado HTTP server. The reverse is not supported; the Tornado
`.Application` and `.RequestHandler` classes are designed for use with
the Tornado `.HTTPServer` and cannot be used in a generic WSGI
container.
"""
import sys
from io import BytesIO
import tornado
from tornado import escape
from tornado import httputil
from tornado.log import access_log
from typing import List, Tuple, Optional, Callable, Any, Dict, Text
from types import TracebackType
import typing
if typing.TYPE_CHECKING:
from typing import Type # noqa: F401
from wsgiref.types import WSGIApplication as WSGIAppType # noqa: F401
# PEP 3333 specifies that WSGI on python 3 generally deals with byte strings
# that are smuggled inside objects of type unicode (via the latin1 encoding).
# This function is like those in the tornado.escape module, but defined
# here to minimize the temptation to use it in non-wsgi contexts.
def to_wsgi_str(s: bytes) -> str:
assert isinstance(s, bytes)
return s.decode("latin1")
class WSGIContainer(object):
r"""Makes a WSGI-compatible function runnable on Tornado's HTTP server.
.. warning::
WSGI is a *synchronous* interface, while Tornado's concurrency model
is based on single-threaded asynchronous execution. This means that
running a WSGI app with Tornado's `WSGIContainer` is *less scalable*
than running the same app in a multi-threaded WSGI server like
``gunicorn`` or ``uwsgi``. Use `WSGIContainer` only when there are
benefits to combining Tornado and WSGI in the same process that
outweigh the reduced scalability.
Wrap a WSGI function in a `WSGIContainer` and pass it to `.HTTPServer` to
run it. For example::
def simple_app(environ, start_response):
status = "200 OK"
response_headers = [("Content-type", "text/plain")]
start_response(status, response_headers)
return ["Hello world!\n"]
container = tornado.wsgi.WSGIContainer(simple_app)
http_server = tornado.httpserver.HTTPServer(container)
http_server.listen(8888)
tornado.ioloop.IOLoop.current().start()
This class is intended to let other frameworks (Django, web.py, etc)
run on the Tornado HTTP server and I/O loop.
The `tornado.web.FallbackHandler` class is often useful for mixing
Tornado and WSGI apps in the same server. See
https://github.com/bdarnell/django-tornado-demo for a complete example.
"""
def __init__(self, wsgi_application: "WSGIAppType") -> None:
self.wsgi_application = wsgi_application
def __call__(self, request: httputil.HTTPServerRequest) -> None:
data = {} # type: Dict[str, Any]
response = [] # type: List[bytes]
def start_response(
status: str,
headers: List[Tuple[str, str]],
exc_info: Optional[
Tuple[
"Optional[Type[BaseException]]",
Optional[BaseException],
Optional[TracebackType],
]
] = None,
) -> Callable[[bytes], Any]:
data["status"] = status
data["headers"] = headers
return response.append
app_response = self.wsgi_application(
WSGIContainer.environ(request), start_response
)
try:
response.extend(app_response)
body = b"".join(response)
finally:
if hasattr(app_response, "close"):
app_response.close() # type: ignore
if not data:
raise Exception("WSGI app did not call start_response")
status_code_str, reason = data["status"].split(" ", 1)
status_code = int(status_code_str)
headers = data["headers"] # type: List[Tuple[str, str]]
header_set = set(k.lower() for (k, v) in headers)
body = escape.utf8(body)
if status_code != 304:
if "content-length" not in header_set:
headers.append(("Content-Length", str(len(body))))
if "content-type" not in header_set:
headers.append(("Content-Type", "text/html; charset=UTF-8"))
if "server" not in header_set:
headers.append(("Server", "TornadoServer/%s" % tornado.version))
start_line = httputil.ResponseStartLine("HTTP/1.1", status_code, reason)
header_obj = httputil.HTTPHeaders()
for key, value in headers:
header_obj.add(key, value)
assert request.connection is not None
request.connection.write_headers(start_line, header_obj, chunk=body)
request.connection.finish()
self._log(status_code, request)
@staticmethod
def environ(request: httputil.HTTPServerRequest) -> Dict[Text, Any]:
"""Converts a `tornado.httputil.HTTPServerRequest` to a WSGI environment.
"""
hostport = request.host.split(":")
if len(hostport) == 2:
host = hostport[0]
port = int(hostport[1])
else:
host = request.host
port = 443 if request.protocol == "https" else 80
environ = {
"REQUEST_METHOD": request.method,
"SCRIPT_NAME": "",
"PATH_INFO": to_wsgi_str(
escape.url_unescape(request.path, encoding=None, plus=False)
),
"QUERY_STRING": request.query,
"REMOTE_ADDR": request.remote_ip,
"SERVER_NAME": host,
"SERVER_PORT": str(port),
"SERVER_PROTOCOL": request.version,
"wsgi.version": (1, 0),
"wsgi.url_scheme": request.protocol,
"wsgi.input": BytesIO(escape.utf8(request.body)),
"wsgi.errors": sys.stderr,
"wsgi.multithread": False,
"wsgi.multiprocess": True,
"wsgi.run_once": False,
}
if "Content-Type" in request.headers:
environ["CONTENT_TYPE"] = request.headers.pop("Content-Type")
if "Content-Length" in request.headers:
environ["CONTENT_LENGTH"] = request.headers.pop("Content-Length")
for key, value in request.headers.items():
environ["HTTP_" + key.replace("-", "_").upper()] = value
return environ
def _log(self, status_code: int, request: httputil.HTTPServerRequest) -> None:
if status_code < 400:
log_method = access_log.info
elif status_code < 500:
log_method = access_log.warning
else:
log_method = access_log.error
request_time = 1000.0 * request.request_time()
assert request.method is not None
assert request.uri is not None
summary = request.method + " " + request.uri + " (" + request.remote_ip + ")"
log_method("%d %s %.2fms", status_code, summary, request_time)
HTTPRequest = httputil.HTTPServerRequest