Upgraded vendored Python dependencies to the latest versions and removed unused dependencies.

morpheus65535 2022-01-23 23:07:52 -05:00
parent 36bf0d219d
commit 0c3c5a02a7
2108 changed files with 306789 additions and 151391 deletions

View file

@@ -52,7 +52,10 @@ def refine_from_ffprobe(path, video):
if isinstance(data['ffprobe']['video'][0]['frame_rate'], float):
video.fps = data['ffprobe']['video'][0]['frame_rate']
else:
video.fps = data['ffprobe']['video'][0]['frame_rate'].magnitude
try:
video.fps = data['ffprobe']['video'][0]['frame_rate'].magnitude
except AttributeError:
video.fps = data['ffprobe']['video'][0]['frame_rate']
if 'audio' not in data['ffprobe']:
logging.debug('BAZARR FFprobe was unable to find audio tracks in the file!')
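The replacement trades type inspection for EAFP: `frame_rate` may arrive as a plain float or as a unit-carrying value exposing `.magnitude` (as with pint quantities). A minimal sketch of the pattern, with a hypothetical stand-in for the unit-carrying type:

    # Sketch of the EAFP pattern adopted above; Quantity is a hypothetical
    # stand-in for a unit-carrying value such as a pint quantity.
    class Quantity:
        def __init__(self, magnitude):
            self.magnitude = magnitude

    def extract_fps(frame_rate):
        try:
            return frame_rate.magnitude   # unit-carrying value
        except AttributeError:
            return frame_rate             # already a plain number

    assert extract_fps(23.976) == 23.976
    assert extract_fps(Quantity(23.976)) == 23.976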

View file

@@ -184,9 +184,6 @@ def init_binaries():
except Exception:
logging.debug("custom check failed for: %s", exe)
rarfile.OPEN_ARGS = rarfile.ORIG_OPEN_ARGS
rarfile.EXTRACT_ARGS = rarfile.ORIG_EXTRACT_ARGS
rarfile.TEST_ARGS = rarfile.ORIG_TEST_ARGS
logging.debug("Using UnRAR from: %s", exe)
unrar = exe
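The three deleted lines restored rarfile's module-level `OPEN_ARGS`/`EXTRACT_ARGS`/`TEST_ARGS` tuples from stashed `ORIG_*` copies; as of rarfile 4.x those module constants are gone, so the restore had to go with them. A version-tolerant restore would probe first; a sketch, assuming the `ORIG_*` copies were stashed by the caller beforehand:

    from types import SimpleNamespace

    # Stand-in for the rarfile module; ORIG_OPEN_ARGS is assumed to have been
    # stashed by the caller before custom binary checks mutated the original.
    rarfile = SimpleNamespace(ORIG_OPEN_ARGS=("p", "-inul"), OPEN_ARGS=None)

    for name in ("OPEN_ARGS", "EXTRACT_ARGS", "TEST_ARGS"):
        orig = getattr(rarfile, "ORIG_" + name, None)
        if orig is not None:
            setattr(rarfile, name, orig)  # restore only where a copy exists

    assert rarfile.OPEN_ARGS == ("p", "-inul")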

View file

@@ -7,6 +7,8 @@ import platform
import warnings
from logging.handlers import TimedRotatingFileHandler
from pytz_deprecation_shim import PytzUsageWarning
from get_args import args
from config import settings
@@ -55,6 +57,7 @@ class NoExceptionFormatter(logging.Formatter):
def configure_logging(debug=False):
warnings.simplefilter('ignore', category=ResourceWarning)
warnings.simplefilter('ignore', category=PytzUsageWarning)
if not debug:
log_level = "INFO"

View file

@@ -1,16 +0,0 @@
try:
import ast
from _markerlib.markers import default_environment, compile, interpret
except ImportError:
if 'ast' in globals():
raise
def default_environment():
return {}
def compile(marker):
def marker_fn(environment=None, override=None):
# 'empty markers are True' heuristic won't install extra deps.
return not marker.strip()
marker_fn.__doc__ = marker
return marker_fn
def interpret(marker, environment=None, override=None):
return compile(marker)()

View file

@@ -1,119 +0,0 @@
# -*- coding: utf-8 -*-
"""Interpret PEP 345 environment markers.
EXPR [in|==|!=|not in] EXPR [or|and] ...
where EXPR belongs to any of those:
python_version = '%s.%s' % (sys.version_info[0], sys.version_info[1])
python_full_version = sys.version.split()[0]
os.name = os.name
sys.platform = sys.platform
platform.version = platform.version()
platform.machine = platform.machine()
platform.python_implementation = platform.python_implementation()
a free string, like '2.6', or 'win32'
"""
__all__ = ['default_environment', 'compile', 'interpret']
import ast
import os
import platform
import sys
import weakref
_builtin_compile = compile
try:
from platform import python_implementation
except ImportError:
if os.name == "java":
# Jython 2.5 has ast module, but not platform.python_implementation() function.
def python_implementation():
return "Jython"
else:
raise
# restricted set of variables
_VARS = {'sys.platform': sys.platform,
'python_version': '%s.%s' % sys.version_info[:2],
# FIXME parsing sys.platform is not reliable, but there is no other
# way to get e.g. 2.7.2+, and the PEP is defined with sys.version
'python_full_version': sys.version.split(' ', 1)[0],
'os.name': os.name,
'platform.version': platform.version(),
'platform.machine': platform.machine(),
'platform.python_implementation': python_implementation(),
'extra': None # wheel extension
}
for var in list(_VARS.keys()):
if '.' in var:
_VARS[var.replace('.', '_')] = _VARS[var]
def default_environment():
"""Return copy of default PEP 385 globals dictionary."""
return dict(_VARS)
class ASTWhitelist(ast.NodeTransformer):
def __init__(self, statement):
self.statement = statement # for error messages
ALLOWED = (ast.Compare, ast.BoolOp, ast.Attribute, ast.Name, ast.Load, ast.Str)
# Bool operations
ALLOWED += (ast.And, ast.Or)
# Comparison operations
ALLOWED += (ast.Eq, ast.Gt, ast.GtE, ast.In, ast.Is, ast.IsNot, ast.Lt, ast.LtE, ast.NotEq, ast.NotIn)
def visit(self, node):
"""Ensure statement only contains allowed nodes."""
if not isinstance(node, self.ALLOWED):
raise SyntaxError('Not allowed in environment markers.\n%s\n%s' %
(self.statement,
(' ' * node.col_offset) + '^'))
return ast.NodeTransformer.visit(self, node)
def visit_Attribute(self, node):
"""Flatten one level of attribute access."""
new_node = ast.Name("%s.%s" % (node.value.id, node.attr), node.ctx)
return ast.copy_location(new_node, node)
def parse_marker(marker):
tree = ast.parse(marker, mode='eval')
new_tree = ASTWhitelist(marker).generic_visit(tree)
return new_tree
def compile_marker(parsed_marker):
return _builtin_compile(parsed_marker, '<environment marker>', 'eval',
dont_inherit=True)
_cache = weakref.WeakValueDictionary()
def compile(marker):
"""Return compiled marker as a function accepting an environment dict."""
try:
return _cache[marker]
except KeyError:
pass
if not marker.strip():
def marker_fn(environment=None, override=None):
""""""
return True
else:
compiled_marker = compile_marker(parse_marker(marker))
def marker_fn(environment=None, override=None):
"""override updates environment"""
if override is None:
override = {}
if environment is None:
environment = default_environment()
environment.update(override)
return eval(compiled_marker, environment)
marker_fn.__doc__ = marker
_cache[marker] = marker_fn
return _cache[marker]
def interpret(marker, environment=None):
return compile(marker)(environment)
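Both `_markerlib` modules are deleted outright rather than upgraded; this is the old setuptools `_markerlib`, long since dropped upstream, and marker evaluation now lives in the `packaging` library. A rough modern equivalent of `interpret()` (a sketch, assuming `packaging` is installed):

    from packaging.markers import Marker

    # Evaluate a PEP 508 environment marker, roughly what interpret() did
    # for the older PEP 345 syntax.
    marker = Marker('python_version >= "3.6" and os_name == "posix"')
    print(marker.evaluate())                   # against the running interpreter
    print(marker.evaluate({"os_name": "nt"}))  # with an override, like the old
                                               # 'override' argument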

View file

@@ -13,8 +13,8 @@ See <http://github.com/ActiveState/appdirs> for details and usage.
# - Mac OS X: http://developer.apple.com/documentation/MacOSX/Conceptual/BPFileSystem/index.html
# - XDG spec for Un*x: http://standards.freedesktop.org/basedir-spec/basedir-spec-latest.html
__version_info__ = (1, 4, 3)
__version__ = '.'.join(map(str, __version_info__))
__version__ = "1.4.4"
__version_info__ = tuple(int(segment) for segment in __version__.split("."))
import sys
@@ -98,7 +98,7 @@ def user_data_dir(appname=None, appauthor=None, version=None, roaming=False):
def site_data_dir(appname=None, appauthor=None, version=None, multipath=False):
"""Return full path to the user-shared data dir for this application.
r"""Return full path to the user-shared data dir for this application.
"appname" is the name of application.
If None, just the system directory is returned.
@@ -204,7 +204,7 @@ def user_config_dir(appname=None, appauthor=None, version=None, roaming=False):
def site_config_dir(appname=None, appauthor=None, version=None, multipath=False):
"""Return full path to the user-shared data dir for this application.
r"""Return full path to the user-shared data dir for this application.
"appname" is the name of application.
If None, just the system directory is returned.
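The `r"""` prefix on the two docstrings is the substance of the appdirs 1.4.4 bump here: those docstrings spell out Windows paths with backslashes, which modern Python flags as invalid escape sequences unless the string is raw. Typical usage of the functions touched (a sketch; the app name is illustrative):

    import appdirs

    # Per-user and shared directories; results differ by platform.
    print(appdirs.user_data_dir("MyApp", appauthor=False))
    print(appdirs.site_config_dir("MyApp", appauthor=False, multipath=True))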

View file

@@ -1,116 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2014-2016 The arghelper developers. All rights reserved.
# Project site: https://github.com/questrail/arghelper
# Use of this source code is governed by a MIT-style license that
# can be found in the LICENSE.txt file for the project.
"""Provide helper functions for argparse
"""
# Try to future proof code so that it's Python 3.x ready
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
# Standard module imports
import argparse
import sys
import os
def extant_file(arg):
"""Facade for extant_item(arg, arg_type="file")
"""
return extant_item(arg, "file")
def extant_dir(arg):
"""Facade for extant_item(arg, arg_type="directory")
"""
return extant_item(arg, "directory")
def extant_item(arg, arg_type):
"""Determine if parser argument is an existing file or directory.
This technique comes from http://stackoverflow.com/a/11541450/95592
and from http://stackoverflow.com/a/11541495/95592
Args:
arg: parser argument containing filename to be checked
arg_type: string of either "file" or "directory"
Returns:
If the file exists, return the filename or directory.
Raises:
If the file does not exist, raise a parser error.
"""
if arg_type == "file":
if not os.path.isfile(arg):
raise argparse.ArgumentError(
None,
"The file {arg} does not exist.".format(arg=arg))
else:
# File exists so return the filename
return arg
elif arg_type == "directory":
if not os.path.isdir(arg):
raise argparse.ArgumentError(
None,
"The directory {arg} does not exist.".format(arg=arg))
else:
# Directory exists so return the directory name
return arg
def parse_config_input_output(args=sys.argv):
"""Parse the args using the config_file, input_dir, output_dir pattern
Args:
args: sys.argv
Returns:
The populated namespace object from parser.parse_args().
Raises:
TBD
"""
parser = argparse.ArgumentParser(
description='Process the input files using the given config')
parser.add_argument(
'config_file',
help='Configuration file.',
metavar='FILE', type=extant_file)
parser.add_argument(
'input_dir',
help='Directory containing the input files.',
metavar='DIR', type=extant_dir)
parser.add_argument(
'output_dir',
help='Directory where the output files should be saved.',
metavar='DIR', type=extant_dir)
return parser.parse_args(args[1:])
def parse_config(args=sys.argv):
"""Parse the args using the config_file pattern
Args:
args: sys.argv
Returns:
The populated namespace object from parser.parse_args().
Raises:
TBD
"""
parser = argparse.ArgumentParser(
description='Read in the config file')
parser.add_argument(
'config_file',
help='Configuration file.',
metavar='FILE', type=extant_file)
return parser.parse_args(args[1:])
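The deleted `arghelper` module wrapped one argparse idiom: validating at parse time that a path argument exists. The unvendored equivalent is a small `type=` callable; a sketch using `ArgumentTypeError`, argparse's intended exception for this case:

    import argparse
    import os

    def extant_file(arg):
        """argparse 'type' callable that rejects nonexistent files."""
        if not os.path.isfile(arg):
            raise argparse.ArgumentTypeError(
                "The file {arg} does not exist.".format(arg=arg))
        return arg

    parser = argparse.ArgumentParser(description="Read in the config file")
    parser.add_argument("config_file", metavar="FILE", type=extant_file)
    # parser.parse_args(["missing.ini"]) exits with a clean usage error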

View file

@@ -1,61 +0,0 @@
# Copyright 2013 Dean Gardiner <gardiner91@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from asio.file import SEEK_ORIGIN_CURRENT
from asio.file_opener import FileOpener
from asio.open_parameters import OpenParameters
from asio.interfaces.posix import PosixInterface
from asio.interfaces.windows import WindowsInterface
import os
class ASIO(object):
platform_handler = None
@classmethod
def get_handler(cls):
if cls.platform_handler:
return cls.platform_handler
if os.name == 'nt':
cls.platform_handler = WindowsInterface
elif os.name == 'posix':
cls.platform_handler = PosixInterface
else:
raise NotImplementedError()
return cls.platform_handler
@classmethod
def open(cls, file_path, opener=True, parameters=None):
"""Open file
:type file_path: str
:param opener: Use FileOpener, for use with the 'with' statement
:type opener: bool
:rtype: asio.file.File
"""
if not parameters:
parameters = OpenParameters()
if opener:
return FileOpener(file_path, parameters)
return ASIO.get_handler().open(
file_path,
parameters=parameters.handlers.get(ASIO.get_handler())
)
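For context, the removed package was consumed roughly like this (reconstructed from the code above, so no longer importable after this commit; `opener=True` is the default and yields the `FileOpener` context manager, so the handle closes itself):

    from asio import ASIO
    from asio.file import SEEK_ORIGIN_BEGIN

    # Reconstructed usage of the removed package: read a chunk through the
    # platform-selected interface (posix or windows).
    with ASIO.open("/var/log/example.log") as f:
        f.seek(0, SEEK_ORIGIN_BEGIN)
        chunk = f.read(4096)
        print(len(chunk) if chunk else 0)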

View file

@@ -1,92 +0,0 @@
# Copyright 2013 Dean Gardiner <gardiner91@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from io import RawIOBase
import time
DEFAULT_BUFFER_SIZE = 4096
SEEK_ORIGIN_BEGIN = 0
SEEK_ORIGIN_CURRENT = 1
SEEK_ORIGIN_END = 2
class ReadTimeoutError(Exception):
pass
class File(RawIOBase):
platform_handler = None
def __init__(self, *args, **kwargs):
super(File, self).__init__(*args, **kwargs)
def get_handler(self):
"""
:rtype: asio.interfaces.base.Interface
"""
if not self.platform_handler:
raise ValueError()
return self.platform_handler
def get_size(self):
"""Get the current file size
:rtype: int
"""
return self.get_handler().get_size(self)
def get_path(self):
"""Get the path of this file
:rtype: str
"""
return self.get_handler().get_path(self)
def seek(self, offset, origin):
"""Sets a reference point of a file to the given value.
:param offset: The point relative to origin to move
:type offset: int
:param origin: Reference point to seek (SEEK_ORIGIN_BEGIN, SEEK_ORIGIN_CURRENT, SEEK_ORIGIN_END)
:type origin: int
"""
return self.get_handler().seek(self, offset, origin)
def read(self, n=-1):
"""Read up to n bytes from the object and return them.
:type n: int
:rtype: str
"""
return self.get_handler().read(self, n)
def readinto(self, b):
"""Read up to len(b) bytes into bytearray b and return the number of bytes read."""
data = self.read(len(b))
if data is None:
return None
b[:len(data)] = data
return len(data)
def close(self):
"""Close the file handle"""
return self.get_handler().close(self)
def readable(self, *args, **kwargs):
return True

View file

@@ -1,21 +0,0 @@
class FileOpener(object):
def __init__(self, file_path, parameters=None):
self.file_path = file_path
self.parameters = parameters
self.file = None
def __enter__(self):
self.file = ASIO.get_handler().open(
self.file_path,
self.parameters.handlers.get(ASIO.get_handler())
)
return self.file
def __exit__(self, exc_type, exc_val, exc_tb):
if not self.file:
return
self.file.close()
self.file = None

View file

@@ -1,41 +0,0 @@
# Copyright 2013 Dean Gardiner <gardiner91@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from asio.file import DEFAULT_BUFFER_SIZE
class Interface(object):
@classmethod
def open(cls, file_path, parameters=None):
raise NotImplementedError()
@classmethod
def get_size(cls, fp):
raise NotImplementedError()
@classmethod
def get_path(cls, fp):
raise NotImplementedError()
@classmethod
def seek(cls, fp, pointer, distance):
raise NotImplementedError()
@classmethod
def read(cls, fp, n=DEFAULT_BUFFER_SIZE):
raise NotImplementedError()
@classmethod
def close(cls, fp):
raise NotImplementedError()

View file

@@ -1,123 +0,0 @@
# Copyright 2013 Dean Gardiner <gardiner91@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from asio.file import File, DEFAULT_BUFFER_SIZE
from asio.interfaces.base import Interface
import sys
import os
if os.name == 'posix':
import select
# fcntl is only required on darwin
if sys.platform == 'darwin':
import fcntl
F_GETPATH = 50
class PosixInterface(Interface):
@classmethod
def open(cls, file_path, parameters=None):
"""
:type file_path: str
:rtype: asio.interfaces.posix.PosixFile
"""
if not parameters:
parameters = {}
if not parameters.get('mode'):
parameters.pop('mode')
if not parameters.get('buffering'):
parameters.pop('buffering')
fd = os.open(file_path, os.O_RDONLY | os.O_NONBLOCK)
return PosixFile(fd)
@classmethod
def get_size(cls, fp):
"""
:type fp: asio.interfaces.posix.PosixFile
:rtype: int
"""
return os.fstat(fp.fd).st_size
@classmethod
def get_path(cls, fp):
"""
:type fp: asio.interfaces.posix.PosixFile
:rtype: int
"""
# readlink /dev/fd fails on darwin, so instead use fcntl F_GETPATH
if sys.platform == 'darwin':
return fcntl.fcntl(fp.fd, F_GETPATH, '\0' * 1024).rstrip('\0')
# Use /proc/self/fd if available
if os.path.lexists("/proc/self/fd/"):
return os.readlink("/proc/self/fd/%s" % fp.fd)
# Fallback to /dev/fd
if os.path.lexists("/dev/fd/"):
return os.readlink("/dev/fd/%s" % fp.fd)
raise NotImplementedError('Environment not supported (fdescfs not mounted?)')
@classmethod
def seek(cls, fp, offset, origin):
"""
:type fp: asio.interfaces.posix.PosixFile
:type offset: int
:type origin: int
"""
os.lseek(fp.fd, offset, origin)
@classmethod
def read(cls, fp, n=DEFAULT_BUFFER_SIZE):
"""
:type fp: asio.interfaces.posix.PosixFile
:type n: int
:rtype: str
"""
r, w, x = select.select([fp.fd], [], [], 5)
if r:
return os.read(fp.fd, n)
return None
@classmethod
def close(cls, fp):
"""
:type fp: asio.interfaces.posix.PosixFile
"""
os.close(fp.fd)
class PosixFile(File):
platform_handler = PosixInterface
def __init__(self, fd, *args, **kwargs):
"""
:type fd: asio.file.File
"""
super(PosixFile, self).__init__(*args, **kwargs)
self.fd = fd
def __str__(self):
return "<asio_posix.PosixFile file: %s>" % self.fd

View file

@@ -1,201 +0,0 @@
# Copyright 2013 Dean Gardiner <gardiner91@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from asio.file import File, DEFAULT_BUFFER_SIZE
from asio.interfaces.base import Interface
import os
NULL = 0
if os.name == 'nt':
from asio.interfaces.windows.interop import WindowsInterop
class WindowsInterface(Interface):
@classmethod
def open(cls, file_path, parameters=None):
"""
:type file_path: str
:rtype: asio.interfaces.windows.WindowsFile
"""
if not parameters:
parameters = {}
return WindowsFile(WindowsInterop.create_file(
file_path,
parameters.get('desired_access', WindowsInterface.GenericAccess.READ),
parameters.get('share_mode', WindowsInterface.ShareMode.ALL),
parameters.get('creation_disposition', WindowsInterface.CreationDisposition.OPEN_EXISTING),
parameters.get('flags_and_attributes', NULL)
))
@classmethod
def get_size(cls, fp):
"""
:type fp: asio.interfaces.windows.WindowsFile
:rtype: int
"""
return WindowsInterop.get_file_size(fp.handle)
@classmethod
def get_path(cls, fp):
"""
:type fp: asio.interfaces.windows.WindowsFile
:rtype: str
"""
if not fp.file_map:
fp.file_map = WindowsInterop.create_file_mapping(fp.handle, WindowsInterface.Protection.READONLY)
if not fp.map_view:
fp.map_view = WindowsInterop.map_view_of_file(fp.file_map, WindowsInterface.FileMapAccess.READ, 1)
file_name = WindowsInterop.get_mapped_file_name(fp.map_view)
return file_name
@classmethod
def seek(cls, fp, offset, origin):
"""
:type fp: asio.interfaces.windows.WindowsFile
:type offset: int
:type origin: int
:rtype: int
"""
return WindowsInterop.set_file_pointer(
fp.handle,
offset,
origin
)
@classmethod
def read(cls, fp, n=DEFAULT_BUFFER_SIZE):
"""
:type fp: asio.interfaces.windows.WindowsFile
:type n: int
:rtype: str
"""
return WindowsInterop.read(fp.handle, n)
@classmethod
def read_into(cls, fp, b):
"""
:type fp: asio.interfaces.windows.WindowsFile
:type b: str
:rtype: int
"""
return WindowsInterop.read_into(fp.handle, b)
@classmethod
def close(cls, fp):
"""
:type fp: asio.interfaces.windows.WindowsFile
:rtype: bool
"""
if fp.map_view:
WindowsInterop.unmap_view_of_file(fp.map_view)
if fp.file_map:
WindowsInterop.close_handle(fp.file_map)
return bool(WindowsInterop.close_handle(fp.handle))
class GenericAccess(object):
READ = 0x80000000
WRITE = 0x40000000
EXECUTE = 0x20000000
ALL = 0x10000000
class ShareMode(object):
READ = 0x00000001
WRITE = 0x00000002
DELETE = 0x00000004
ALL = READ | WRITE | DELETE
class CreationDisposition(object):
CREATE_NEW = 1
CREATE_ALWAYS = 2
OPEN_EXISTING = 3
OPEN_ALWAYS = 4
TRUNCATE_EXISTING = 5
class Attribute(object):
READONLY = 0x00000001
HIDDEN = 0x00000002
SYSTEM = 0x00000004
DIRECTORY = 0x00000010
ARCHIVE = 0x00000020
DEVICE = 0x00000040
NORMAL = 0x00000080
TEMPORARY = 0x00000100
SPARSE_FILE = 0x00000200
REPARSE_POINT = 0x00000400
COMPRESSED = 0x00000800
OFFLINE = 0x00001000
NOT_CONTENT_INDEXED = 0x00002000
ENCRYPTED = 0x00004000
class Flag(object):
WRITE_THROUGH = 0x80000000
OVERLAPPED = 0x40000000
NO_BUFFERING = 0x20000000
RANDOM_ACCESS = 0x10000000
SEQUENTIAL_SCAN = 0x08000000
DELETE_ON_CLOSE = 0x04000000
BACKUP_SEMANTICS = 0x02000000
POSIX_SEMANTICS = 0x01000000
OPEN_REPARSE_POINT = 0x00200000
OPEN_NO_RECALL = 0x00100000
FIRST_PIPE_INSTANCE = 0x00080000
class Protection(object):
NOACCESS = 0x01
READONLY = 0x02
READWRITE = 0x04
WRITECOPY = 0x08
EXECUTE = 0x10
EXECUTE_READ = 0x20,
EXECUTE_READWRITE = 0x40
EXECUTE_WRITECOPY = 0x80
GUARD = 0x100
NOCACHE = 0x200
WRITECOMBINE = 0x400
class FileMapAccess(object):
COPY = 0x0001
WRITE = 0x0002
READ = 0x0004
ALL_ACCESS = 0x001f
EXECUTE = 0x0020
class WindowsFile(File):
platform_handler = WindowsInterface
def __init__(self, handle, *args, **kwargs):
super(WindowsFile, self).__init__(*args, **kwargs)
self.handle = handle
self.file_map = None
self.map_view = None
def readinto(self, b):
return self.get_handler().read_into(self, b)
def __str__(self):
return "<asio_windows.WindowsFile file: %s>" % self.handle

View file

@@ -1,230 +0,0 @@
# Copyright 2013 Dean Gardiner <gardiner91@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ctypes.wintypes import *
from ctypes import *
import logging
log = logging.getLogger(__name__)
CreateFileW = windll.kernel32.CreateFileW
CreateFileW.argtypes = (LPCWSTR, DWORD, DWORD, c_void_p, DWORD, DWORD, HANDLE)
CreateFileW.restype = HANDLE
ReadFile = windll.kernel32.ReadFile
ReadFile.argtypes = (HANDLE, c_void_p, DWORD, POINTER(DWORD), HANDLE)
ReadFile.restype = BOOL
NULL = 0
MAX_PATH = 260
DEFAULT_BUFFER_SIZE = 4096
LPSECURITY_ATTRIBUTES = c_void_p
class WindowsInterop(object):
ri_buffer = None
@classmethod
def create_file(cls, path, desired_access, share_mode, creation_disposition, flags_and_attributes):
h = CreateFileW(
path,
desired_access,
share_mode,
NULL,
creation_disposition,
flags_and_attributes,
NULL
)
error = GetLastError()
if error != 0:
raise Exception('[WindowsASIO.open] "%s"' % FormatError(error))
return h
@classmethod
def read(cls, handle, buf_size=DEFAULT_BUFFER_SIZE):
buf = create_string_buffer(buf_size)
bytes_read = c_ulong(0)
success = ReadFile(handle, buf, buf_size, byref(bytes_read), NULL)
error = GetLastError()
if error:
log.debug('read_file - error: (%s) "%s"', error, FormatError(error))
if not success and error:
raise Exception('[WindowsInterop.read_file] (%s) "%s"' % (error, FormatError(error)))
# Return if we have a valid buffer
if success and bytes_read.value:
return buf.value
return None
@classmethod
def read_into(cls, handle, b):
if cls.ri_buffer is None or len(cls.ri_buffer) < len(b):
cls.ri_buffer = create_string_buffer(len(b))
bytes_read = c_ulong(0)
success = ReadFile(handle, cls.ri_buffer, len(b), byref(bytes_read), NULL)
bytes_read = int(bytes_read.value)
b[:bytes_read] = cls.ri_buffer[:bytes_read]
error = GetLastError()
if not success and error:
raise Exception('[WindowsInterop.read_file] (%s) "%s"' % (error, FormatError(error)))
# Return if we have a valid buffer
if success and bytes_read:
return bytes_read
return None
@classmethod
def set_file_pointer(cls, handle, distance, method):
pos_high = DWORD(NULL)
result = windll.kernel32.SetFilePointer(
handle,
c_ulong(distance),
byref(pos_high),
DWORD(method)
)
if result == -1:
raise Exception('[WindowsASIO.seek] INVALID_SET_FILE_POINTER: "%s"' % FormatError(GetLastError()))
return result
@classmethod
def get_file_size(cls, handle):
return windll.kernel32.GetFileSize(
handle,
DWORD(NULL)
)
@classmethod
def close_handle(cls, handle):
return windll.kernel32.CloseHandle(handle)
@classmethod
def create_file_mapping(cls, handle, protect, maximum_size_high=0, maximum_size_low=1):
return HANDLE(windll.kernel32.CreateFileMappingW(
handle,
LPSECURITY_ATTRIBUTES(NULL),
DWORD(protect),
DWORD(maximum_size_high),
DWORD(maximum_size_low),
LPCSTR(NULL)
))
@classmethod
def map_view_of_file(cls, map_handle, desired_access, num_bytes, file_offset_high=0, file_offset_low=0):
return HANDLE(windll.kernel32.MapViewOfFile(
map_handle,
DWORD(desired_access),
DWORD(file_offset_high),
DWORD(file_offset_low),
num_bytes
))
@classmethod
def unmap_view_of_file(cls, view_handle):
return windll.kernel32.UnmapViewOfFile(view_handle)
@classmethod
def get_mapped_file_name(cls, view_handle, translate_device_name=True):
buf = create_string_buffer(MAX_PATH + 1)
result = windll.psapi.GetMappedFileNameW(
cls.get_current_process(),
view_handle,
buf,
MAX_PATH
)
# Raise exception on error
error = GetLastError()
if result == 0:
raise Exception(FormatError(error))
# Retrieve a clean file name (skipping over NUL bytes)
file_name = cls.clean_buffer_value(buf)
# If we are not translating the device name return here
if not translate_device_name:
return file_name
drives = cls.get_logical_drive_strings()
# Find the drive matching the file_name device name
translated = False
for drive in drives:
device_name = cls.query_dos_device(drive)
if file_name.startswith(device_name):
file_name = drive + file_name[len(device_name):]
translated = True
break
if not translated:
raise Exception('Unable to translate device name')
return file_name
@classmethod
def get_logical_drive_strings(cls, buf_size=512):
buf = create_string_buffer(buf_size)
result = windll.kernel32.GetLogicalDriveStringsW(buf_size, buf)
error = GetLastError()
if result == 0:
raise Exception(FormatError(error))
drive_strings = cls.clean_buffer_value(buf)
return [dr for dr in drive_strings.split('\\') if dr != '']
@classmethod
def query_dos_device(cls, drive, buf_size=MAX_PATH):
buf = create_string_buffer(buf_size)
result = windll.kernel32.QueryDosDeviceA(
drive,
buf,
buf_size
)
return cls.clean_buffer_value(buf)
@classmethod
def get_current_process(cls):
return HANDLE(windll.kernel32.GetCurrentProcess())
@classmethod
def clean_buffer_value(cls, buf):
value = ""
for ch in buf.raw:
if ord(ch) != 0:
value += ch
return value

View file

@@ -1,47 +0,0 @@
from asio.interfaces.posix import PosixInterface
from asio.interfaces.windows import WindowsInterface
class OpenParameters(object):
def __init__(self):
self.handlers = {}
# Update handler_parameters with defaults
self.posix()
self.windows()
def posix(self, mode=None, buffering=None):
"""
:type mode: str
:type buffering: int
"""
self.handlers.update({PosixInterface: {
'mode': mode,
'buffering': buffering
}})
def windows(self, desired_access=WindowsInterface.GenericAccess.READ,
share_mode=WindowsInterface.ShareMode.ALL,
creation_disposition=WindowsInterface.CreationDisposition.OPEN_EXISTING,
flags_and_attributes=0):
"""
:param desired_access: WindowsInterface.DesiredAccess
:type desired_access: int
:param share_mode: WindowsInterface.ShareMode
:type share_mode: int
:param creation_disposition: WindowsInterface.CreationDisposition
:type creation_disposition: int
:param flags_and_attributes: WindowsInterface.Attribute, WindowsInterface.Flag
:type flags_and_attributes: int
"""
self.handlers.update({WindowsInterface: {
'desired_access': desired_access,
'share_mode': share_mode,
'creation_disposition': creation_disposition,
'flags_and_attributes': flags_and_attributes
}})

View file

@@ -2,20 +2,16 @@
:author:
Amine SEHILI <amine.sehili@gmail.com>
2015-2016
2015-2021
:License:
This package is published under GNU GPL Version 3.
This package is published under the MIT license.
"""
from __future__ import absolute_import
from .core import *
from .io import *
from .util import *
from . import dataset
from .exceptions import *
__version__ = "0.1.5"
__version__ = "0.2.0"

File diff suppressed because it is too large.

View file: libs/auditok/cmdline_util.py (new executable file, 126 lines)

@@ -0,0 +1,126 @@
import sys
import logging
from collections import namedtuple
from . import workers
from .util import AudioDataSource
from .io import player_for
_AUDITOK_LOGGER = "AUDITOK_LOGGER"
KeywordArguments = namedtuple(
"KeywordArguments", ["io", "split", "miscellaneous"]
)
def make_kwargs(args_ns):
if args_ns.save_stream is None:
record = args_ns.plot or (args_ns.save_image is not None)
else:
record = False
try:
use_channel = int(args_ns.use_channel)
except (ValueError, TypeError):
use_channel = args_ns.use_channel
io_kwargs = {
"input": args_ns.input,
"audio_format": args_ns.input_format,
"max_read": args_ns.max_read,
"block_dur": args_ns.analysis_window,
"sampling_rate": args_ns.sampling_rate,
"sample_width": args_ns.sample_width,
"channels": args_ns.channels,
"use_channel": use_channel,
"save_stream": args_ns.save_stream,
"save_detections_as": args_ns.save_detections_as,
"export_format": args_ns.output_format,
"large_file": args_ns.large_file,
"frames_per_buffer": args_ns.frame_per_buffer,
"input_device_index": args_ns.input_device_index,
"record": record,
}
split_kwargs = {
"min_dur": args_ns.min_duration,
"max_dur": args_ns.max_duration,
"max_silence": args_ns.max_silence,
"drop_trailing_silence": args_ns.drop_trailing_silence,
"strict_min_dur": args_ns.strict_min_duration,
"energy_threshold": args_ns.energy_threshold,
}
miscellaneous = {
"echo": args_ns.echo,
"progress_bar": args_ns.progress_bar,
"command": args_ns.command,
"quiet": args_ns.quiet,
"printf": args_ns.printf,
"time_format": args_ns.time_format,
"timestamp_format": args_ns.timestamp_format,
}
return KeywordArguments(io_kwargs, split_kwargs, miscellaneous)
def make_logger(stderr=False, file=None, name=_AUDITOK_LOGGER):
if not stderr and file is None:
return None
logger = logging.getLogger(name)
logger.setLevel(logging.INFO)
if stderr:
handler = logging.StreamHandler(sys.stderr)
handler.setLevel(logging.INFO)
logger.addHandler(handler)
if file is not None:
handler = logging.FileHandler(file, "w")
fmt = logging.Formatter("[%(asctime)s] | %(message)s")
handler.setFormatter(fmt)
handler.setLevel(logging.INFO)
logger.addHandler(handler)
return logger
def initialize_workers(logger=None, **kwargs):
observers = []
reader = AudioDataSource(source=kwargs["input"], **kwargs)
if kwargs["save_stream"] is not None:
reader = workers.StreamSaverWorker(
reader,
filename=kwargs["save_stream"],
export_format=kwargs["export_format"],
)
reader.start()
if kwargs["save_detections_as"] is not None:
worker = workers.RegionSaverWorker(
kwargs["save_detections_as"],
kwargs["export_format"],
logger=logger,
)
observers.append(worker)
if kwargs["echo"]:
player = player_for(reader)
worker = workers.PlayerWorker(
player, progress_bar=kwargs["progress_bar"], logger=logger
)
observers.append(worker)
if kwargs["command"] is not None:
worker = workers.CommandLineWorker(
command=kwargs["command"], logger=logger
)
observers.append(worker)
if not kwargs["quiet"]:
print_format = (
kwargs["printf"]
.replace("\\n", "\n")
.replace("\\t", "\t")
.replace("\\r", "\r")
)
worker = workers.PrintWorker(
print_format, kwargs["time_format"], kwargs["timestamp_format"]
)
observers.append(worker)
return reader, observers
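`make_logger` returns `None` unless there is at least one destination, which spares callers a level of guarding. A usage sketch:

    from auditok.cmdline_util import make_logger

    logger = make_logger(stderr=True, file="auditok.log")  # INFO to both sinks
    if logger is not None:
        logger.info("[DET]: example record")

    assert make_logger() is None  # no stderr, no file: nothing to build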

File diff suppressed because it is too large.

View file

@@ -1,19 +1,31 @@
"""
This module contains links to audio files you can use for test purposes.
This module contains links to audio files that can be used for test purposes.
.. autosummary::
:toctree: generated/
one_to_six_arabic_16000_mono_bc_noise
was_der_mensch_saet_mono_44100_lead_trail_silence
"""
import os
__all__ = ["one_to_six_arabic_16000_mono_bc_noise", "was_der_mensch_saet_mono_44100_lead_trail_silence"]
__all__ = [
"one_to_six_arabic_16000_mono_bc_noise",
"was_der_mensch_saet_mono_44100_lead_trail_silence",
]
_current_dir = os.path.dirname(os.path.realpath(__file__))
one_to_six_arabic_16000_mono_bc_noise = "{cd}{sep}data{sep}1to6arabic_\
16000_mono_bc_noise.wav".format(cd=_current_dir, sep=os.path.sep)
16000_mono_bc_noise.wav".format(
cd=_current_dir, sep=os.path.sep
)
"""A wave file that contains a pronunciation of Arabic numbers from 1 to 6"""
was_der_mensch_saet_mono_44100_lead_trail_silence = "{cd}{sep}data{sep}was_\
der_mensch_saet_das_wird_er_vielfach_ernten_44100Hz_mono_lead_trail_\
silence.wav".format(cd=_current_dir, sep=os.path.sep)
""" A wave file that contains a sentence between long leading and trailing periods of silence"""
silence.wav".format(
cd=_current_dir, sep=os.path.sep
)
"""A wave file that contains a sentence with a long leading and trailing silence"""

View file

@@ -1,9 +1,41 @@
"""
November 2015
@author: Amine SEHILI <amine.sehili@gmail.com>
"""
class DuplicateArgument(Exception):
pass
class TooSamllBlockDuration(ValueError):
"""Raised when block_dur results in a block_size smaller than one sample."""
def __init__(self, message, block_dur, sampling_rate):
self.block_dur = block_dur
self.sampling_rate = sampling_rate
super(TooSamllBlockDuration, self).__init__(message)
class TimeFormatError(Exception):
"""Raised when a duration formatting directive is unknown."""
class EndOfProcessing(Exception):
"""Raised within command line script's main function to jump to
postprocessing code."""
class AudioIOError(Exception):
"""Raised when a compressed audio file cannot be loaded or when trying
to read from a not yet open AudioSource"""
class AudioParameterError(AudioIOError):
"""Raised when one audio parameter is missing when loading raw data or
saving data to a format other than raw. Also raised when an audio
parameter has a wrong value."""
class AudioEncodingError(Exception):
"""Raised if audio data can not be encoded in the provided format"""
class AudioEncodingWarning(RuntimeWarning):
"""Raised if audio data can not be encoded in the provided format
but saved as wav.
"""

File diff suppressed because it is too large.

View file: libs/auditok/plotting.py (new executable file, 150 lines)

@@ -0,0 +1,150 @@
import matplotlib.pyplot as plt
import numpy as np
AUDITOK_PLOT_THEME = {
"figure": {"facecolor": "#482a36", "alpha": 0.2},
"plot": {"facecolor": "#282a36"},
"energy_threshold": {
"color": "#e31f8f",
"linestyle": "--",
"linewidth": 1,
},
"signal": {"color": "#40d970", "linestyle": "-", "linewidth": 1},
"detections": {
"facecolor": "#777777",
"edgecolor": "#ff8c1a",
"linewidth": 1,
"alpha": 0.75,
},
}
def _make_time_axis(nb_samples, sampling_rate):
sample_duration = 1 / sampling_rate
x = np.linspace(0, sample_duration * (nb_samples - 1), nb_samples)
return x
def _plot_line(x, y, theme, xlabel=None, ylabel=None, **kwargs):
color = theme.get("color", theme.get("c"))
ls = theme.get("linestyle", theme.get("ls"))
lw = theme.get("linewidth", theme.get("lw"))
plt.plot(x, y, c=color, ls=ls, lw=lw, **kwargs)
plt.xlabel(xlabel, fontsize=8)
plt.ylabel(ylabel, fontsize=8)
def _plot_detections(subplot, detections, theme):
fc = theme.get("facecolor", theme.get("fc"))
ec = theme.get("edgecolor", theme.get("ec"))
ls = theme.get("linestyle", theme.get("ls"))
lw = theme.get("linewidth", theme.get("lw"))
alpha = theme.get("alpha")
for (start, end) in detections:
subplot.axvspan(start, end, fc=fc, ec=ec, ls=ls, lw=lw, alpha=alpha)
def plot(
audio_region,
scale_signal=True,
detections=None,
energy_threshold=None,
show=True,
figsize=None,
save_as=None,
dpi=120,
theme="auditok",
):
y = np.asarray(audio_region)
if len(y.shape) == 1:
y = y.reshape(1, -1)
nb_subplots, nb_samples = y.shape
sampling_rate = audio_region.sampling_rate
time_axis = _make_time_axis(nb_samples, sampling_rate)
if energy_threshold is not None:
eth_log10 = energy_threshold * np.log(10) / 10
amplitude_threshold = np.sqrt(np.exp(eth_log10))
else:
amplitude_threshold = None
if detections is None:
detections = []
else:
# End of detection corresponds to the end of the last sample but
# to stay compatible with the time axis of signal plotting we want end
# of detection to correspond to the *start* of the that last sample.
detections = [
(start, end - (1 / sampling_rate)) for (start, end) in detections
]
if theme == "auditok":
theme = AUDITOK_PLOT_THEME
fig = plt.figure(figsize=figsize, dpi=dpi)
fig_theme = theme.get("figure", theme.get("fig", {}))
fig_fc = fig_theme.get("facecolor", fig_theme.get("ffc"))
fig_alpha = fig_theme.get("alpha", 1)
fig.patch.set_facecolor(fig_fc)
fig.patch.set_alpha(fig_alpha)
plot_theme = theme.get("plot", {})
plot_fc = plot_theme.get("facecolor", plot_theme.get("pfc"))
if nb_subplots > 2 and nb_subplots % 2 == 0:
nb_rows = nb_subplots // 2
nb_columns = 2
else:
nb_rows = nb_subplots
nb_columns = 1
for sid, samples in enumerate(y, 1):
ax = fig.add_subplot(nb_rows, nb_columns, sid)
ax.set_facecolor(plot_fc)
if scale_signal:
std = samples.std()
if std > 0:
mean = samples.mean()
std = samples.std()
samples = (samples - mean) / std
max_ = samples.max()
plt.ylim(-1.5 * max_, 1.5 * max_)
if amplitude_threshold is not None:
if scale_signal and std > 0:
amp_th = (amplitude_threshold - mean) / std
else:
amp_th = amplitude_threshold
eth_theme = theme.get("energy_threshold", theme.get("eth", {}))
_plot_line(
[time_axis[0], time_axis[-1]],
[amp_th] * 2,
eth_theme,
label="Detection threshold",
)
if sid == 1:
legend = plt.legend(
["Detection threshold"],
facecolor=fig_fc,
framealpha=0.1,
bbox_to_anchor=(0.0, 1.15, 1.0, 0.102),
loc=2,
)
legend = plt.gca().add_artist(legend)
signal_theme = theme.get("signal", {})
_plot_line(
time_axis,
samples,
signal_theme,
xlabel="Time (seconds)",
ylabel="Signal{}".format(" (scaled)" if scale_signal else ""),
)
detections_theme = theme.get("detections", {})
_plot_detections(ax, detections, detections_theme)
plt.title("Channel {}".format(sid), fontsize=10)
plt.xticks(fontsize=8)
plt.yticks(fontsize=8)
plt.tight_layout()
if save_as is not None:
plt.savefig(save_as, dpi=dpi)
if show:
plt.show()
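The two-line threshold conversion above is just the inverse of the energy definition used in signal.py, energy = 20 * log10(rms): sqrt(exp(E * ln(10) / 10)) simplifies to 10 ** (E / 20). A quick numerical check:

    import numpy as np

    # sqrt(exp(E*ln10/10)) == 10**(E/20); both invert energy = 20*log10(rms).
    E = 50.0
    assert np.isclose(np.sqrt(np.exp(E * np.log(10) / 10)), 10 ** (E / 20))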

View file: libs/auditok/signal.py (new file, 179 lines)

@@ -0,0 +1,179 @@
"""
Module for basic audio signal processing and array operations.
.. autosummary::
:toctree: generated/
to_array
extract_single_channel
compute_average_channel
compute_average_channel_stereo
separate_channels
calculate_energy_single_channel
calculate_energy_multichannel
"""
from array import array as array_
import audioop
import math
FORMAT = {1: "b", 2: "h", 4: "i"}
_EPSILON = 1e-10
def to_array(data, sample_width, channels):
"""Extract individual channels of audio data and return a list of arrays of
numeric samples. This will always return a list of `array.array` objects
(one per channel) even if audio data is mono.
Parameters
----------
data : bytes
raw audio data.
sample_width : int
size in bytes of one audio sample (one channel considered).
Returns
-------
samples_arrays : list
list of arrays of audio samples.
"""
fmt = FORMAT[sample_width]
if channels == 1:
return [array_(fmt, data)]
return separate_channels(data, fmt, channels)
def extract_single_channel(data, fmt, channels, selected):
samples = array_(fmt, data)
return samples[selected::channels]
def compute_average_channel(data, fmt, channels):
"""
Compute and return average channel of multi-channel audio data. If the
number of channels is 2, use :func:`compute_average_channel_stereo` (much
faster). This function uses satandard `array` module to convert `bytes` data
into an array of numeric values.
Parameters
----------
data : bytes
multi-channel audio data to mix down.
fmt : str
format (single character) to pass to `array.array` to convert `data`
into an array of samples. This should be "b" if audio data's sample width
is 1, "h" if it's 2 and "i" if it's 4.
channels : int
number of channels of audio data.
Returns
-------
mono_audio : bytes
mixed down audio data.
"""
all_channels = array_(fmt, data)
mono_channels = [
array_(fmt, all_channels[ch::channels]) for ch in range(channels)
]
avg_arr = array_(
fmt,
(round(sum(samples) / channels) for samples in zip(*mono_channels)),
)
return avg_arr
def compute_average_channel_stereo(data, sample_width):
"""Compute and return average channel of stereo audio data. This function
should be used when the number of channels is exactly 2 because in that
case we can use standard `audioop` module which *much* faster then calling
:func:`compute_average_channel`.
Parameters
----------
data : bytes
2-channel audio data to mix down.
sample_width : int
size in bytes of one audio sample (one channel considered).
Returns
-------
mono_audio : bytes
mixed down audio data.
"""
fmt = FORMAT[sample_width]
arr = array_(fmt, audioop.tomono(data, sample_width, 0.5, 0.5))
return arr
def separate_channels(data, fmt, channels):
"""Create a list of arrays of audio samples (`array.array` objects), one for
each channel.
Parameters
----------
data : bytes
multi-channel audio data to mix down.
fmt : str
format (single character) to pass to `array.array` to convert `data`
into an array of samples. This should be "b" if audio data's sample width
is 1, "h" if it's 2 and "i" if it's 4.
channels : int
number of channels of audio data.
Returns
-------
channels_arr : list
list of audio channels, each as a standard `array.array`.
"""
all_channels = array_(fmt, data)
mono_channels = [
array_(fmt, all_channels[ch::channels]) for ch in range(channels)
]
return mono_channels
def calculate_energy_single_channel(data, sample_width):
"""Calculate the energy of mono audio data. Energy is computed as:
.. math:: energy = 20 \log(\sqrt({1}/{N}\sum_{i}^{N}{a_i}^2)) % # noqa: W605
where `a_i` is the i-th audio sample and `N` is the number of audio samples
in data.
Parameters
----------
data : bytes
single-channel audio data.
sample_width : int
size in bytes of one audio sample.
Returns
-------
energy : float
energy of audio signal.
"""
energy_sqrt = max(audioop.rms(data, sample_width), _EPSILON)
return 20 * math.log10(energy_sqrt)
def calculate_energy_multichannel(x, sample_width, aggregation_fn=max):
"""Calculate the energy of multi-channel audio data. Energy is calculated
channel-wise. An aggregation function is applied to the resulting energies
(default: `max`). Also see :func:`calculate_energy_single_channel`.
Parameters
----------
data : bytes
single-channel audio data.
sample_width : int
size in bytes of one audio sample (one channel considered).
aggregation_fn : callable, default: max
aggregation function to apply to the resulting per-channel energies.
Returns
-------
energy : float
aggregated energy of multi-channel audio signal.
"""
energies = (calculate_energy_single_channel(xi, sample_width) for xi in x)
return aggregation_fn(energies)
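A hand check of `calculate_energy_single_channel` on four 16-bit samples; the `audioop.rms` call is exactly what the function wraps:

    from array import array
    import audioop
    import math

    samples = array("h", [1000, -1000, 2000, -2000])   # 16-bit mono
    rms = audioop.rms(samples.tobytes(), 2)            # integer RMS: 1581
    energy = 20 * math.log10(max(rms, 1e-10))          # ~63.98 dB
    print(energy)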

View file

@@ -0,0 +1,30 @@
import numpy as np
from .signal import (
compute_average_channel_stereo,
calculate_energy_single_channel,
calculate_energy_multichannel,
)
FORMAT = {1: np.int8, 2: np.int16, 4: np.int32}
def to_array(data, sample_width, channels):
fmt = FORMAT[sample_width]
if channels == 1:
return np.frombuffer(data, dtype=fmt).astype(np.float64)
return separate_channels(data, fmt, channels).astype(np.float64)
def extract_single_channel(data, fmt, channels, selected):
samples = np.frombuffer(data, dtype=fmt)
return np.asanyarray(samples[selected::channels], order="C")
def compute_average_channel(data, fmt, channels):
array = np.frombuffer(data, dtype=fmt).astype(np.float64)
return array.reshape(-1, channels).mean(axis=1).round().astype(fmt)
def separate_channels(data, fmt, channels):
array = np.frombuffer(data, dtype=fmt)
return np.asanyarray(array.reshape(-1, channels).T, order="C")
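The `reshape(-1, channels).T` trick is the whole deinterleaving story: frames become rows, channels become columns, and the transpose gives one row per channel. The transposed view is no longer C-contiguous, which is why the code requests an `order="C"` copy. A two-channel demonstration:

    import numpy as np

    # Interleaved stereo [L0, R0, L1, R1, ...] -> one row per channel.
    interleaved = np.array([0, 100, 1, 101, 2, 102], dtype=np.int16)
    left, right = interleaved.reshape(-1, 2).T
    print(left)    # [0 1 2]
    print(right)   # [100 101 102]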

File diff suppressed because it is too large.

View file: libs/auditok/workers.py (new executable file, 427 lines)

@@ -0,0 +1,427 @@
import os
import sys
from tempfile import NamedTemporaryFile
from abc import ABCMeta, abstractmethod
from threading import Thread
from datetime import datetime, timedelta
from collections import namedtuple
import wave
import subprocess
from queue import Queue, Empty
from .io import _guess_audio_format
from .util import AudioDataSource, make_duration_formatter
from .core import split
from .exceptions import (
EndOfProcessing,
AudioEncodingError,
AudioEncodingWarning,
)
_STOP_PROCESSING = "STOP_PROCESSING"
_Detection = namedtuple("_Detection", "id start end duration")
def _run_subprocess(command):
try:
with subprocess.Popen(
command,
stdin=open(os.devnull, "rb"),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
) as proc:
stdout, stderr = proc.communicate()
return proc.returncode, stdout, stderr
except Exception:
err_msg = "Couldn't export audio using command: '{}'".format(command)
raise AudioEncodingError(err_msg)
class Worker(Thread, metaclass=ABCMeta):
def __init__(self, timeout=0.5, logger=None):
self._timeout = timeout
self._logger = logger
self._inbox = Queue()
Thread.__init__(self)
def run(self):
while True:
message = self._get_message()
if message == _STOP_PROCESSING:
break
if message is not None:
self._process_message(message)
self._post_process()
@abstractmethod
def _process_message(self, message):
"""Process incoming messages"""
def _post_process(self):
pass
def _log(self, message):
self._logger.info(message)
def _stop_requested(self):
try:
message = self._inbox.get_nowait()
if message == _STOP_PROCESSING:
return True
except Empty:
return False
def stop(self):
self.send(_STOP_PROCESSING)
self.join()
def send(self, message):
self._inbox.put(message)
def _get_message(self):
try:
message = self._inbox.get(timeout=self._timeout)
return message
except Empty:
return None
class TokenizerWorker(Worker, AudioDataSource):
def __init__(self, reader, observers=None, logger=None, **kwargs):
self._observers = observers if observers is not None else []
self._reader = reader
self._audio_region_gen = split(self, **kwargs)
self._detections = []
self._log_format = "[DET]: Detection {0.id} (start: {0.start:.3f}, "
self._log_format += "end: {0.end:.3f}, duration: {0.duration:.3f})"
Worker.__init__(self, timeout=0.2, logger=logger)
def _process_message(self):
pass
@property
def detections(self):
return self._detections
def _notify_observers(self, message):
for observer in self._observers:
observer.send(message)
def run(self):
self._reader.open()
start_processing_timestamp = datetime.now()
for _id, audio_region in enumerate(self._audio_region_gen, start=1):
timestamp = start_processing_timestamp + timedelta(
seconds=audio_region.meta.start
)
audio_region.meta.timestamp = timestamp
detection = _Detection(
_id,
audio_region.meta.start,
audio_region.meta.end,
audio_region.duration,
)
self._detections.append(detection)
if self._logger is not None:
message = self._log_format.format(detection)
self._log(message)
self._notify_observers((_id, audio_region))
self._notify_observers(_STOP_PROCESSING)
self._reader.close()
def start_all(self):
for observer in self._observers:
observer.start()
self.start()
def stop_all(self):
self.stop()
for observer in self._observers:
observer.stop()
self._reader.close()
def read(self):
if self._stop_requested():
return None
else:
return self._reader.read()
def __getattr__(self, name):
return getattr(self._reader, name)
class StreamSaverWorker(Worker):
def __init__(
self,
audio_reader,
filename,
export_format=None,
cache_size_sec=0.5,
timeout=0.2,
):
self._reader = audio_reader
sample_size_bytes = self._reader.sw * self._reader.ch
self._cache_size = cache_size_sec * self._reader.sr * sample_size_bytes
self._output_filename = filename
self._export_format = _guess_audio_format(export_format, filename)
if self._export_format is None:
self._export_format = "wav"
self._init_output_stream()
self._exported = False
self._cache = []
self._total_cached = 0
Worker.__init__(self, timeout=timeout)
def _get_non_existent_filename(self):
filename = self._output_filename + ".wav"
i = 0
while os.path.exists(filename):
i += 1
filename = self._output_filename + "({}).wav".format(i)
return filename
def _init_output_stream(self):
if self._export_format != "wav":
self._tmp_output_filename = self._get_non_existent_filename()
else:
self._tmp_output_filename = self._output_filename
self._wfp = wave.open(self._tmp_output_filename, "wb")
self._wfp.setframerate(self._reader.sr)
self._wfp.setsampwidth(self._reader.sw)
self._wfp.setnchannels(self._reader.ch)
@property
def sr(self):
return self._reader.sampling_rate
@property
def sw(self):
return self._reader.sample_width
@property
def ch(self):
return self._reader.channels
def __del__(self):
self._post_process()
if (
(self._tmp_output_filename != self._output_filename)
and self._exported
and os.path.exists(self._tmp_output_filename)
):
os.remove(self._tmp_output_filename)
def _process_message(self, data):
self._cache.append(data)
self._total_cached += len(data)
if self._total_cached >= self._cache_size:
self._write_cached_data()
def _post_process(self):
while True:
try:
data = self._inbox.get_nowait()
if data != _STOP_PROCESSING:
self._cache.append(data)
self._total_cached += len(data)
except Empty:
break
self._write_cached_data()
self._wfp.close()
def _write_cached_data(self):
if self._cache:
data = b"".join(self._cache)
self._wfp.writeframes(data)
self._cache = []
self._total_cached = 0
def open(self):
self._reader.open()
def close(self):
self._reader.close()
self.stop()
def rewind(self):
# ensure compatibility with AudioDataSource with record=True
pass
@property
def data(self):
with wave.open(self._tmp_output_filename, "rb") as wfp:
return wfp.readframes(-1)
def save_stream(self):
if self._exported:
return self._output_filename
if self._export_format in ("raw", "wav"):
if self._export_format == "raw":
self._export_raw()
self._exported = True
return self._output_filename
try:
self._export_with_ffmpeg_or_avconv()
except AudioEncodingError:
try:
self._export_with_sox()
except AudioEncodingError:
warn_msg = "Couldn't save audio data in the desired format "
warn_msg += "'{}'. Either none of 'ffmpeg', 'avconv' or 'sox' "
warn_msg += "is installed or this format is not recognized.\n"
warn_msg += "Audio file was saved as '{}'"
raise AudioEncodingWarning(
warn_msg.format(
self._export_format, self._tmp_output_filename
)
)
finally:
self._exported = True
return self._output_filename
def _export_raw(self):
with open(self._output_filename, "wb") as wfp:
wfp.write(self.data)
def _export_with_ffmpeg_or_avconv(self):
command = [
"-y",
"-f",
"wav",
"-i",
self._tmp_output_filename,
"-f",
self._export_format,
self._output_filename,
]
returncode, stdout, stderr = _run_subprocess(["ffmpeg"] + command)
if returncode != 0:
returncode, stdout, stderr = _run_subprocess(["avconv"] + command)
if returncode != 0:
raise AudioEncodingError(stderr)
return stdout, stderr
def _export_with_sox(self):
command = [
"sox",
"-t",
"wav",
self._tmp_output_filename,
self._output_filename,
]
returncode, stdout, stderr = _run_subprocess(command)
if returncode != 0:
raise AudioEncodingError(stderr)
return stdout, stderr
def close_output(self):
self._wfp.close()
def read(self):
data = self._reader.read()
if data is not None:
self.send(data)
else:
self.send(_STOP_PROCESSING)
return data
def __getattr__(self, name):
if name == "data":
return self.data
return getattr(self._reader, name)
class PlayerWorker(Worker):
def __init__(self, player, progress_bar=False, timeout=0.2, logger=None):
self._player = player
self._progress_bar = progress_bar
self._log_format = "[PLAY]: Detection {id} played"
Worker.__init__(self, timeout=timeout, logger=logger)
def _process_message(self, message):
_id, audio_region = message
if self._logger is not None:
message = self._log_format.format(id=_id)
self._log(message)
audio_region.play(
player=self._player, progress_bar=self._progress_bar, leave=False
)
class RegionSaverWorker(Worker):
def __init__(
self,
filename_format,
audio_format=None,
timeout=0.2,
logger=None,
**audio_parameters
):
self._filename_format = filename_format
self._audio_format = audio_format
self._audio_parameters = audio_parameters
self._debug_format = "[SAVE]: Detection {id} saved as '{filename}'"
Worker.__init__(self, timeout=timeout, logger=logger)
def _process_message(self, message):
_id, audio_region = message
filename = self._filename_format.format(
id=_id,
start=audio_region.meta.start,
end=audio_region.meta.end,
duration=audio_region.duration,
)
filename = audio_region.save(
filename, self._audio_format, **self._audio_parameters
)
if self._logger:
message = self._debug_format.format(id=_id, filename=filename)
self._log(message)
class CommandLineWorker(Worker):
def __init__(self, command, timeout=0.2, logger=None):
self._command = command
Worker.__init__(self, timeout=timeout, logger=logger)
self._debug_format = "[COMMAND]: Detection {id} command: '{command}'"
def _process_message(self, message):
_id, audio_region = message
with NamedTemporaryFile(delete=False) as file:
filename = audio_region.save(file.name, audio_format="wav")
command = self._command.format(file=filename)
os.system(command)
if self._logger is not None:
message = self._debug_format.format(id=_id, command=command)
self._log(message)
class PrintWorker(Worker):
def __init__(
self,
print_format="{start} {end}",
time_format="%S",
timestamp_format="%Y/%m/%d %H:%M:%S.%f",
timeout=0.2,
):
self._print_format = print_format
self._format_time = make_duration_formatter(time_format)
self._timestamp_format = timestamp_format
self.detections = []
Worker.__init__(self, timeout=timeout)
def _process_message(self, message):
_id, audio_region = message
timestamp = audio_region.meta.timestamp
timestamp = timestamp.strftime(self._timestamp_format)
text = self._print_format.format(
id=_id,
start=self._format_time(audio_region.meta.start),
end=self._format_time(audio_region.meta.end),
duration=self._format_time(audio_region.duration),
timestamp=timestamp,
)
print(text)
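The base `Worker` above is one protocol shared by every class in this file: a `Queue` inbox, a timed `get`, and a sentinel message that ends the loop. A toy reduction of that message loop (a sketch, not the real class):

    from queue import Queue, Empty
    from threading import Thread

    _STOP = "STOP_PROCESSING"

    class EchoWorker(Thread):
        """Toy version of the Worker message loop: inbox, timeout, stop token."""
        def __init__(self):
            super().__init__()
            self._inbox = Queue()

        def run(self):
            while True:
                try:
                    message = self._inbox.get(timeout=0.2)
                except Empty:
                    continue
                if message == _STOP:
                    break
                print("processed:", message)

        def send(self, message):
            self._inbox.put(message)

    w = EchoWorker()
    w.start()
    w.send("detection 1")
    w.send(_STOP)
    w.join()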

View file

@@ -1,11 +1 @@
# A Python "namespace package" http://www.python.org/dev/peps/pep-0382/
# This always goes inside of a namespace package's __init__.py
from pkgutil import extend_path
__path__ = extend_path(__path__, __name__)
try:
import pkg_resources
pkg_resources.declare_namespace(__name__)
except ImportError:
pass
__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore

View file

@@ -4,14 +4,16 @@ import functools
from collections import namedtuple
from threading import RLock
_CacheInfo = namedtuple("CacheInfo", ["hits", "misses", "maxsize", "currsize"])
_CacheInfo = namedtuple("_CacheInfo", ["hits", "misses", "maxsize", "currsize"])
@functools.wraps(functools.update_wrapper)
def update_wrapper(wrapper,
wrapped,
assigned = functools.WRAPPER_ASSIGNMENTS,
updated = functools.WRAPPER_UPDATES):
def update_wrapper(
wrapper,
wrapped,
assigned=functools.WRAPPER_ASSIGNMENTS,
updated=functools.WRAPPER_UPDATES,
):
"""
Patch two bugs in functools.update_wrapper.
"""
@@ -34,10 +36,17 @@ class _HashedSeq(list):
return self.hashvalue
def _make_key(args, kwds, typed,
kwd_mark=(object(),),
fasttypes=set([int, str, frozenset, type(None)]),
sorted=sorted, tuple=tuple, type=type, len=len):
def _make_key(
args,
kwds,
typed,
kwd_mark=(object(),),
fasttypes=set([int, str, frozenset, type(None)]),
sorted=sorted,
tuple=tuple,
type=type,
len=len,
):
'Make a cache key from optionally typed positional and keyword arguments'
key = args
if kwds:
@@ -54,7 +63,7 @@ def _make_key(args, kwds, typed,
return _HashedSeq(key)
def lru_cache(maxsize=100, typed=False):
def lru_cache(maxsize=100, typed=False): # noqa: C901
"""Least-recently-used cache decorator.
If *maxsize* is set to None, the LRU features are disabled and the cache
@@ -82,16 +91,16 @@ def lru_cache(maxsize=100, typed=False):
def decorating_function(user_function):
cache = dict()
stats = [0, 0] # make statistics updateable non-locally
HITS, MISSES = 0, 1 # names for the stats fields
stats = [0, 0] # make statistics updateable non-locally
HITS, MISSES = 0, 1 # names for the stats fields
make_key = _make_key
cache_get = cache.get # bound method to lookup key or return None
_len = len # localize the global len() function
lock = RLock() # because linkedlist updates aren't threadsafe
root = [] # root of the circular doubly linked list
root[:] = [root, root, None, None] # initialize by pointing to self
nonlocal_root = [root] # make updateable non-locally
PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 # names for the link fields
cache_get = cache.get # bound method to lookup key or return None
_len = len # localize the global len() function
lock = RLock() # because linkedlist updates aren't threadsafe
root = [] # root of the circular doubly linked list
root[:] = [root, root, None, None] # initialize by pointing to self
nonlocal_root = [root] # make updateable non-locally
PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 # names for the link fields
if maxsize == 0:
@ -106,7 +115,9 @@ def lru_cache(maxsize=100, typed=False):
def wrapper(*args, **kwds):
# simple caching without ordering or size limit
key = make_key(args, kwds, typed)
result = cache_get(key, root) # root used here as a unique not-found sentinel
result = cache_get(
key, root
) # root used here as a unique not-found sentinel
if result is not root:
stats[HITS] += 1
return result
@ -123,8 +134,9 @@ def lru_cache(maxsize=100, typed=False):
with lock:
link = cache_get(key)
if link is not None:
# record recent use of the key by moving it to the front of the list
root, = nonlocal_root
# record recent use of the key by moving it
# to the front of the list
(root,) = nonlocal_root
link_prev, link_next, key, result = link
link_prev[NEXT] = link_next
link_next[PREV] = link_prev
@ -136,7 +148,7 @@ def lru_cache(maxsize=100, typed=False):
return result
result = user_function(*args, **kwds)
with lock:
root, = nonlocal_root
(root,) = nonlocal_root
if key in cache:
# getting here means that this same key was added to the
# cache while the lock was released. since the link
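# A short usage sketch of the backported decorator, assuming the wrapper
# exposes the usual cache_info() accessor as functools.lru_cache does:
#
#     @lru_cache(maxsize=2)
#     def square(x):
#         return x * x
#
#     square(2); square(2)
#     square.cache_info()  # _CacheInfo(hits=1, misses=1, maxsize=2, currsize=1)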

View file

@ -0,0 +1,49 @@
__all__ = [
"ZoneInfo",
"reset_tzpath",
"available_timezones",
"TZPATH",
"ZoneInfoNotFoundError",
"InvalidTZPathWarning",
]
import sys
from . import _tzpath
from ._common import ZoneInfoNotFoundError
from ._version import __version__
try:
from ._czoneinfo import ZoneInfo
except ImportError: # pragma: nocover
from ._zoneinfo import ZoneInfo
reset_tzpath = _tzpath.reset_tzpath
available_timezones = _tzpath.available_timezones
InvalidTZPathWarning = _tzpath.InvalidTZPathWarning
if sys.version_info < (3, 7):
# Module-level __getattr__ was added in Python 3.7, so instead of lazily
# populating TZPATH on every access, we will register a callback with
# reset_tzpath to update the top-level tuple.
TZPATH = _tzpath.TZPATH
def _tzpath_callback(new_tzpath):
global TZPATH
TZPATH = new_tzpath
_tzpath.TZPATH_CALLBACKS.append(_tzpath_callback)
del _tzpath_callback
else:
def __getattr__(name):
if name == "TZPATH":
return _tzpath.TZPATH
else:
raise AttributeError(
f"module {__name__!r} has no attribute {name!r}"
)
def __dir__():
return sorted(list(globals()) + ["TZPATH"])

View file

@ -0,0 +1,45 @@
import os
import typing
from datetime import datetime, tzinfo
from typing import (
Any,
Iterable,
Optional,
Protocol,
Sequence,
Set,
Type,
Union,
)
_T = typing.TypeVar("_T", bound="ZoneInfo")
class _IOBytes(Protocol):
def read(self, __size: int) -> bytes: ...
def seek(self, __size: int, __whence: int = ...) -> Any: ...
class ZoneInfo(tzinfo):
@property
def key(self) -> str: ...
def __init__(self, key: str) -> None: ...
@classmethod
def no_cache(cls: Type[_T], key: str) -> _T: ...
@classmethod
def from_file(
cls: Type[_T], __fobj: _IOBytes, key: Optional[str] = ...
) -> _T: ...
@classmethod
def clear_cache(cls, *, only_keys: Iterable[str] = ...) -> None: ...
# Note: Both here and in clear_cache, the types allow the use of `str` where
# a sequence of strings is required. This should be remedied if a solution
# to this typing bug is found: https://github.com/python/typing/issues/256
def reset_tzpath(
to: Optional[Sequence[Union[os.PathLike, str]]] = ...
) -> None: ...
def available_timezones() -> Set[str]: ...
TZPATH: Sequence[str]
class ZoneInfoNotFoundError(KeyError): ...
class InvalidTZPathWarning(RuntimeWarning): ...

View file

@ -0,0 +1,171 @@
import struct
def load_tzdata(key):
try:
import importlib.resources as importlib_resources
except ImportError:
import importlib_resources
components = key.split("/")
package_name = ".".join(["tzdata.zoneinfo"] + components[:-1])
resource_name = components[-1]
try:
return importlib_resources.open_binary(package_name, resource_name)
except (ImportError, FileNotFoundError, UnicodeEncodeError):
# There are three types of exception that can be raised that all amount
# to "we cannot find this key":
#
# ImportError: If package_name doesn't exist (e.g. if tzdata is not
# installed, or if there's an error in the folder name like
# Amrica/New_York)
# FileNotFoundError: If resource_name doesn't exist in the package
# (e.g. Europe/Krasnoy)
# UnicodeEncodeError: If package_name or resource_name are not UTF-8,
# such as keys containing a surrogate character.
raise ZoneInfoNotFoundError(f"No time zone found with key {key}")
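# Sketch: the caller is responsible for closing the returned binary stream,
# and a bad key surfaces as ZoneInfoNotFoundError rather than any of the
# three underlying exceptions (assumes the tzdata package is installed):
#
#     with load_tzdata("America/New_York") as f:
#         assert f.read(4) == b"TZif"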
def load_data(fobj):
header = _TZifHeader.from_file(fobj)
if header.version == 1:
time_size = 4
time_type = "l"
else:
# Version 2+ has 64-bit integer transition times
time_size = 8
time_type = "q"
# Version 2+ also starts with a Version 1 header and data, which
# we need to skip now
skip_bytes = (
header.timecnt * 5 # Transition times and types
+ header.typecnt * 6 # Local time type records
+ header.charcnt # Time zone designations
+ header.leapcnt * 8 # Leap second records
+ header.isstdcnt # Standard/wall indicators
+ header.isutcnt # UT/local indicators
)
fobj.seek(skip_bytes, 1)
# Now we need to read the second header, which is not the same
# as the first
header = _TZifHeader.from_file(fobj)
typecnt = header.typecnt
timecnt = header.timecnt
charcnt = header.charcnt
# The data portion starts with timecnt transitions and indices
if timecnt:
trans_list_utc = struct.unpack(
f">{timecnt}{time_type}", fobj.read(timecnt * time_size)
)
trans_idx = struct.unpack(f">{timecnt}B", fobj.read(timecnt))
else:
trans_list_utc = ()
trans_idx = ()
# Read the ttinfo struct, (utoff, isdst, abbrind)
if typecnt:
utcoff, isdst, abbrind = zip(
*(struct.unpack(">lbb", fobj.read(6)) for i in range(typecnt))
)
else:
utcoff = ()
isdst = ()
abbrind = ()
# Now read the abbreviations. They are null-terminated strings, indexed
# not by position in the array but by position in the unsplit
# abbreviation string. I suppose this makes more sense in C, which uses
# null to terminate the strings, but it's inconvenient here...
abbr_vals = {}
abbr_chars = fobj.read(charcnt)
def get_abbr(idx):
# Gets a string starting at idx and running until the next \x00
#
# We cannot pre-populate abbr_vals by splitting on \x00 because there
# are some zones that use subsets of longer abbreviations, like so:
#
# LMT\x00AHST\x00HDT\x00
#
# Where the idx to abbr mapping should be:
#
# {0: "LMT", 4: "AHST", 5: "HST", 9: "HDT"}
if idx not in abbr_vals:
span_end = abbr_chars.find(b"\x00", idx)
abbr_vals[idx] = abbr_chars[idx:span_end].decode()
return abbr_vals[idx]
abbr = tuple(get_abbr(idx) for idx in abbrind)
# The remainder of the file consists of leap seconds (currently unused) and
# the standard/wall and ut/local indicators, which are metadata we don't need.
# In version 2 files, we need to skip the unnecessary data to get at the TZ string:
if header.version >= 2:
# Each leap second record has size (time_size + 4)
skip_bytes = header.isutcnt + header.isstdcnt + header.leapcnt * 12
fobj.seek(skip_bytes, 1)
c = fobj.read(1) # Should be \n
assert c == b"\n", c
tz_bytes = b""
while True:
c = fobj.read(1)
if c == b"\n":
break
tz_bytes += c
tz_str = tz_bytes
else:
tz_str = None
return trans_idx, trans_list_utc, utcoff, isdst, abbr, tz_str
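# Sketch: load_data consumes an already-open binary stream positioned at a
# TZif header and returns the raw arrays; the path below assumes a
# Linux-style system tz database:
#
#     with open("/usr/share/zoneinfo/UTC", "rb") as f:
#         trans_idx, trans_utc, utcoff, isdst, abbr, tz_str = load_data(f)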
class _TZifHeader:
__slots__ = [
"version",
"isutcnt",
"isstdcnt",
"leapcnt",
"timecnt",
"typecnt",
"charcnt",
]
def __init__(self, *args):
assert len(self.__slots__) == len(args)
for attr, val in zip(self.__slots__, args):
setattr(self, attr, val)
@classmethod
def from_file(cls, stream):
# The header starts with a 4-byte "magic" value
if stream.read(4) != b"TZif":
raise ValueError("Invalid TZif file: magic not found")
_version = stream.read(1)
if _version == b"\x00":
version = 1
else:
version = int(_version)
stream.read(15)
args = (version,)
# Slots are defined in the order that the bytes are arranged
args = args + struct.unpack(">6l", stream.read(24))
return cls(*args)
class ZoneInfoNotFoundError(KeyError):
"""Exception raised when a ZoneInfo key is not found."""

View file

@ -0,0 +1,207 @@
import os
import sys
PY36 = sys.version_info < (3, 7)
def reset_tzpath(to=None):
global TZPATH
tzpaths = to
if tzpaths is not None:
if isinstance(tzpaths, (str, bytes)):
raise TypeError(
f"tzpaths must be a list or tuple, "
+ f"not {type(tzpaths)}: {tzpaths!r}"
)
if not all(map(os.path.isabs, tzpaths)):
raise ValueError(_get_invalid_paths_message(tzpaths))
base_tzpath = tzpaths
else:
env_var = os.environ.get("PYTHONTZPATH", None)
if env_var is not None:
base_tzpath = _parse_python_tzpath(env_var)
elif sys.platform != "win32":
base_tzpath = [
"/usr/share/zoneinfo",
"/usr/lib/zoneinfo",
"/usr/share/lib/zoneinfo",
"/etc/zoneinfo",
]
base_tzpath.sort(key=lambda x: not os.path.exists(x))
else:
base_tzpath = ()
TZPATH = tuple(base_tzpath)
if TZPATH_CALLBACKS:
for callback in TZPATH_CALLBACKS:
callback(TZPATH)
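# Sketch of the validation above: only absolute paths are accepted, and a
# bare string or bytes object is rejected outright to catch a common mistake:
#
#     reset_tzpath(["/usr/share/zoneinfo"])  # ok
#     reset_tzpath("/usr/share/zoneinfo")    # TypeError: not a list or tuple
#     reset_tzpath(["zoneinfo"])             # ValueError: relative path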
def _parse_python_tzpath(env_var):
if not env_var:
return ()
raw_tzpath = env_var.split(os.pathsep)
new_tzpath = tuple(filter(os.path.isabs, raw_tzpath))
# If anything has been filtered out, we will warn about it
if len(new_tzpath) != len(raw_tzpath):
import warnings
msg = _get_invalid_paths_message(raw_tzpath)
warnings.warn(
"Invalid paths specified in PYTHONTZPATH environment variable."
+ msg,
InvalidTZPathWarning,
)
return new_tzpath
def _get_invalid_paths_message(tzpaths):
invalid_paths = (path for path in tzpaths if not os.path.isabs(path))
prefix = "\n "
indented_str = prefix + prefix.join(invalid_paths)
return (
"Paths should be absolute but found the following relative paths:"
+ indented_str
)
if sys.version_info < (3, 8):
def _isfile(path):
# bpo-33721: In Python 3.8 non-UTF8 paths return False rather than
# raising an error. See https://bugs.python.org/issue33721
try:
return os.path.isfile(path)
except ValueError:
return False
else:
_isfile = os.path.isfile
def find_tzfile(key):
"""Retrieve the path to a TZif file from a key."""
_validate_tzfile_path(key)
for search_path in TZPATH:
filepath = os.path.join(search_path, key)
if _isfile(filepath):
return filepath
return None
_TEST_PATH = os.path.normpath(os.path.join("_", "_"))[:-1]
def _validate_tzfile_path(path, _base=_TEST_PATH):
if os.path.isabs(path):
raise ValueError(
f"ZoneInfo keys may not be absolute paths, got: {path}"
)
# We only care about the kinds of path normalizations that would change the
# length of the key - e.g. a/../b -> a/b, or a/b/ -> a/b. On Windows,
# normpath will also change from a/b to a\b, but that would still preserve
# the length.
new_path = os.path.normpath(path)
if len(new_path) != len(path):
raise ValueError(
f"ZoneInfo keys must be normalized relative paths, got: {path}"
)
resolved = os.path.normpath(os.path.join(_base, new_path))
if not resolved.startswith(_base):
raise ValueError(
f"ZoneInfo keys must refer to subdirectories of TZPATH, got: {path}"
)
del _TEST_PATH
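# Sketch of how the length-comparison trick above plays out:
#
#     _validate_tzfile_path("America/New_York")  # ok
#     _validate_tzfile_path("a/../b")            # ValueError (normalizes shorter)
#     _validate_tzfile_path("../etc/passwd")     # ValueError (escapes TZPATH)
#     _validate_tzfile_path("/etc/localtime")    # ValueError (absolute path)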
def available_timezones():
"""Returns a set containing all available time zones.
.. caution::
This may attempt to open a large number of files, since the best way to
determine whether a given file on the time zone search path is a time
zone is to open it and check for the "magic string" at the beginning.
"""
try:
from importlib import resources
except ImportError:
import importlib_resources as resources
valid_zones = set()
# Start with loading from the tzdata package if it exists: this has a
# pre-assembled list of zones that only requires opening one file.
try:
with resources.open_text("tzdata", "zones") as f:
for zone in f:
zone = zone.strip()
if zone:
valid_zones.add(zone)
except (ImportError, FileNotFoundError):
pass
def valid_key(fpath):
try:
with open(fpath, "rb") as f:
return f.read(4) == b"TZif"
except Exception: # pragma: nocover
return False
for tz_root in TZPATH:
if not os.path.exists(tz_root):
continue
for root, dirnames, files in os.walk(tz_root):
if root == tz_root:
# right/ and posix/ are special directories and shouldn't be
# included in the output of available zones
if "right" in dirnames:
dirnames.remove("right")
if "posix" in dirnames:
dirnames.remove("posix")
for file in files:
fpath = os.path.join(root, file)
key = os.path.relpath(fpath, start=tz_root)
if os.sep != "/": # pragma: nocover
key = key.replace(os.sep, "/")
if not key or key in valid_zones:
continue
if valid_key(fpath):
valid_zones.add(key)
if "posixrules" in valid_zones:
# posixrules is a special symlink-only time zone; where it exists, it
# should not be included in the output
valid_zones.remove("posixrules")
return valid_zones
class InvalidTZPathWarning(RuntimeWarning):
"""Warning raised if an invalid path is specified in PYTHONTZPATH."""
TZPATH = ()
TZPATH_CALLBACKS = []
reset_tzpath()

View file

@ -0,0 +1 @@
__version__ = "0.2.1"

View file

@ -0,0 +1,754 @@
import bisect
import calendar
import collections
import functools
import re
import weakref
from datetime import datetime, timedelta, tzinfo
from . import _common, _tzpath
EPOCH = datetime(1970, 1, 1)
EPOCHORDINAL = datetime(1970, 1, 1).toordinal()
# It is relatively expensive to construct new timedelta objects, and in most
# cases we're looking at the same deltas, like integer numbers of hours, etc.
# To improve speed and memory use, we'll keep a dictionary with references
# to the ones we've already used so far.
#
# Loading every time zone in the 2020a version of the time zone database
# requires 447 timedeltas, which requires approximately the amount of space
# that ZoneInfo("America/New_York") with 236 transitions takes up, so we will
# set the cache size to 512 so that in the common case we always get cache
# hits, but specifically crafted ZoneInfo objects don't leak arbitrary amounts
# of memory.
@functools.lru_cache(maxsize=512)
def _load_timedelta(seconds):
return timedelta(seconds=seconds)
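# Sketch: because of the cache, equal offsets share a single object:
#
#     assert _load_timedelta(3600) is _load_timedelta(3600)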
class ZoneInfo(tzinfo):
_strong_cache_size = 8
_strong_cache = collections.OrderedDict()
_weak_cache = weakref.WeakValueDictionary()
__module__ = "backports.zoneinfo"
def __init_subclass__(cls):
cls._strong_cache = collections.OrderedDict()
cls._weak_cache = weakref.WeakValueDictionary()
def __new__(cls, key):
instance = cls._weak_cache.get(key, None)
if instance is None:
instance = cls._weak_cache.setdefault(key, cls._new_instance(key))
instance._from_cache = True
# Update the "strong" cache
cls._strong_cache[key] = cls._strong_cache.pop(key, instance)
if len(cls._strong_cache) > cls._strong_cache_size:
cls._strong_cache.popitem(last=False)
return instance
@classmethod
def no_cache(cls, key):
obj = cls._new_instance(key)
obj._from_cache = False
return obj
@classmethod
def _new_instance(cls, key):
obj = super().__new__(cls)
obj._key = key
obj._file_path = obj._find_tzfile(key)
if obj._file_path is not None:
file_obj = open(obj._file_path, "rb")
else:
file_obj = _common.load_tzdata(key)
with file_obj as f:
obj._load_file(f)
return obj
@classmethod
def from_file(cls, fobj, key=None):
obj = super().__new__(cls)
obj._key = key
obj._file_path = None
obj._load_file(fobj)
obj._file_repr = repr(fobj)
# Disable pickling for objects created from files
obj.__reduce__ = obj._file_reduce
return obj
@classmethod
def clear_cache(cls, *, only_keys=None):
if only_keys is not None:
for key in only_keys:
cls._weak_cache.pop(key, None)
cls._strong_cache.pop(key, None)
else:
cls._weak_cache.clear()
cls._strong_cache.clear()
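# Sketch of the caching contract implemented above (assuming tz data for
# "UTC" is available):
#
#     a = ZoneInfo("UTC")
#     assert ZoneInfo("UTC") is a               # __new__ memoizes per key
#     assert ZoneInfo.no_cache("UTC") is not a  # bypasses both caches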
@property
def key(self):
return self._key
def utcoffset(self, dt):
return self._find_trans(dt).utcoff
def dst(self, dt):
return self._find_trans(dt).dstoff
def tzname(self, dt):
return self._find_trans(dt).tzname
def fromutc(self, dt):
"""Convert from datetime in UTC to datetime in local time"""
if not isinstance(dt, datetime):
raise TypeError("fromutc() requires a datetime argument")
if dt.tzinfo is not self:
raise ValueError("dt.tzinfo is not self")
timestamp = self._get_local_timestamp(dt)
num_trans = len(self._trans_utc)
if num_trans >= 1 and timestamp < self._trans_utc[0]:
tti = self._tti_before
fold = 0
elif (
num_trans == 0 or timestamp > self._trans_utc[-1]
) and not isinstance(self._tz_after, _ttinfo):
tti, fold = self._tz_after.get_trans_info_fromutc(
timestamp, dt.year
)
elif num_trans == 0:
tti = self._tz_after
fold = 0
else:
idx = bisect.bisect_right(self._trans_utc, timestamp)
if num_trans > 1 and timestamp >= self._trans_utc[1]:
tti_prev, tti = self._ttinfos[idx - 2 : idx]
elif timestamp > self._trans_utc[-1]:
tti_prev = self._ttinfos[-1]
tti = self._tz_after
else:
tti_prev = self._tti_before
tti = self._ttinfos[0]
# Detect fold
shift = tti_prev.utcoff - tti.utcoff
fold = shift.total_seconds() > timestamp - self._trans_utc[idx - 1]
dt += tti.utcoff
if fold:
return dt.replace(fold=1)
else:
return dt
def _find_trans(self, dt):
if dt is None:
if self._fixed_offset:
return self._tz_after
else:
return _NO_TTINFO
ts = self._get_local_timestamp(dt)
lt = self._trans_local[dt.fold]
num_trans = len(lt)
if num_trans and ts < lt[0]:
return self._tti_before
elif not num_trans or ts > lt[-1]:
if isinstance(self._tz_after, _TZStr):
return self._tz_after.get_trans_info(ts, dt.year, dt.fold)
else:
return self._tz_after
else:
# idx is the transition that occurs after this timestamp, so we
# subtract off 1 to get the current ttinfo
idx = bisect.bisect_right(lt, ts) - 1
assert idx >= 0
return self._ttinfos[idx]
def _get_local_timestamp(self, dt):
return (
(dt.toordinal() - EPOCHORDINAL) * 86400
+ dt.hour * 3600
+ dt.minute * 60
+ dt.second
)
def __str__(self):
if self._key is not None:
return f"{self._key}"
else:
return repr(self)
def __repr__(self):
if self._key is not None:
return f"{self.__class__.__name__}(key={self._key!r})"
else:
return f"{self.__class__.__name__}.from_file({self._file_repr})"
def __reduce__(self):
return (self.__class__._unpickle, (self._key, self._from_cache))
def _file_reduce(self):
import pickle
raise pickle.PicklingError(
"Cannot pickle a ZoneInfo file created from a file stream."
)
@classmethod
def _unpickle(cls, key, from_cache):
if from_cache:
return cls(key)
else:
return cls.no_cache(key)
def _find_tzfile(self, key):
return _tzpath.find_tzfile(key)
def _load_file(self, fobj):
# Retrieve all the data as it exists in the zoneinfo file
trans_idx, trans_utc, utcoff, isdst, abbr, tz_str = _common.load_data(
fobj
)
# Infer the DST offsets (needed for .dst()) from the data
dstoff = self._utcoff_to_dstoff(trans_idx, utcoff, isdst)
# Convert all the transition times (UTC) into "seconds since 1970-01-01 local time"
trans_local = self._ts_to_local(trans_idx, trans_utc, utcoff)
# Construct `_ttinfo` objects for each transition in the file
_ttinfo_list = [
_ttinfo(
_load_timedelta(utcoffset), _load_timedelta(dstoffset), tzname
)
for utcoffset, dstoffset, tzname in zip(utcoff, dstoff, abbr)
]
self._trans_utc = trans_utc
self._trans_local = trans_local
self._ttinfos = [_ttinfo_list[idx] for idx in trans_idx]
# Find the first non-DST transition
for i in range(len(isdst)):
if not isdst[i]:
self._tti_before = _ttinfo_list[i]
break
else:
if self._ttinfos:
self._tti_before = self._ttinfos[0]
else:
self._tti_before = None
# Set the "fallback" time zone
if tz_str is not None and tz_str != b"":
self._tz_after = _parse_tz_str(tz_str.decode())
else:
if not self._ttinfos and not _ttinfo_list:
raise ValueError("No time zone information found.")
if self._ttinfos:
self._tz_after = self._ttinfos[-1]
else:
self._tz_after = _ttinfo_list[-1]
# Determine if this is a "fixed offset" zone, meaning that the output
# of the utcoffset, dst and tzname functions does not depend on the
# specific datetime passed.
#
# We make three simplifying assumptions here:
#
# 1. If _tz_after is not a _ttinfo, it has transitions that might
# actually occur (it is possible to construct TZ strings that
# specify STD and DST but no transitions ever occur, such as
# AAA0BBB,0/0,J365/25).
# 2. If _ttinfo_list contains more than one _ttinfo object, the objects
# represent different offsets.
# 3. _ttinfo_list contains no unused _ttinfos (in which case an
# otherwise fixed-offset zone with extra _ttinfos defined may
# appear to *not* be a fixed offset zone).
#
# Violations to these assumptions would be fairly exotic, and exotic
# zones should almost certainly not be used with datetime.time (the
# only thing that would be affected by this).
if len(_ttinfo_list) > 1 or not isinstance(self._tz_after, _ttinfo):
self._fixed_offset = False
elif not _ttinfo_list:
self._fixed_offset = True
else:
self._fixed_offset = _ttinfo_list[0] == self._tz_after
@staticmethod
def _utcoff_to_dstoff(trans_idx, utcoffsets, isdsts):
# Now we must transform our ttis and abbrs into `_ttinfo` objects,
# but there is an issue: .dst() must return a timedelta with the
# difference between utcoffset() and the "standard" offset, but
# the "base offset" and "DST offset" are not encoded in the file;
# we can infer what they are from the isdst flag, but it is not
# sufficient to just look at the last standard offset, because
# occasionally countries will shift both DST offset and base offset.
typecnt = len(isdsts)
dstoffs = [0] * typecnt # Provisionally assign all to 0.
dst_cnt = sum(isdsts)
dst_found = 0
for i in range(1, len(trans_idx)):
if dst_cnt == dst_found:
break
idx = trans_idx[i]
dst = isdsts[idx]
# We're only going to look at daylight saving time
if not dst:
continue
# Skip any offsets that have already been assigned
if dstoffs[idx] != 0:
continue
dstoff = 0
utcoff = utcoffsets[idx]
comp_idx = trans_idx[i - 1]
if not isdsts[comp_idx]:
dstoff = utcoff - utcoffsets[comp_idx]
if not dstoff and idx < (typecnt - 1):
comp_idx = trans_idx[i + 1]
# If the following transition is also DST and we couldn't
# find the DST offset by this point, we're going to have to
# skip it and hope this transition gets assigned later
if isdsts[comp_idx]:
continue
dstoff = utcoff - utcoffsets[comp_idx]
if dstoff:
dst_found += 1
dstoffs[idx] = dstoff
else:
# If we didn't find a valid value for a given index, we'll end up
# with dstoff = 0 for something where `isdst=1`. This is obviously
# wrong - one hour will be a much better guess than 0
for idx in range(typecnt):
if not dstoffs[idx] and isdsts[idx]:
dstoffs[idx] = 3600
return dstoffs
@staticmethod
def _ts_to_local(trans_idx, trans_list_utc, utcoffsets):
"""Generate number of seconds since 1970 *in the local time*.
This is necessary to easily find the transition times in local time"""
if not trans_list_utc:
return [[], []]
# Start with the timestamps and modify in-place
trans_list_wall = [list(trans_list_utc), list(trans_list_utc)]
if len(utcoffsets) > 1:
offset_0 = utcoffsets[0]
offset_1 = utcoffsets[trans_idx[0]]
if offset_1 > offset_0:
offset_1, offset_0 = offset_0, offset_1
else:
offset_0 = offset_1 = utcoffsets[0]
trans_list_wall[0][0] += offset_0
trans_list_wall[1][0] += offset_1
for i in range(1, len(trans_idx)):
offset_0 = utcoffsets[trans_idx[i - 1]]
offset_1 = utcoffsets[trans_idx[i]]
if offset_1 > offset_0:
offset_1, offset_0 = offset_0, offset_1
trans_list_wall[0][i] += offset_0
trans_list_wall[1][i] += offset_1
return trans_list_wall
class _ttinfo:
__slots__ = ["utcoff", "dstoff", "tzname"]
def __init__(self, utcoff, dstoff, tzname):
self.utcoff = utcoff
self.dstoff = dstoff
self.tzname = tzname
def __eq__(self, other):
return (
self.utcoff == other.utcoff
and self.dstoff == other.dstoff
and self.tzname == other.tzname
)
def __repr__(self): # pragma: nocover
return (
f"{self.__class__.__name__}"
+ f"({self.utcoff}, {self.dstoff}, {self.tzname})"
)
_NO_TTINFO = _ttinfo(None, None, None)
class _TZStr:
__slots__ = (
"std",
"dst",
"start",
"end",
"get_trans_info",
"get_trans_info_fromutc",
"dst_diff",
)
def __init__(
self, std_abbr, std_offset, dst_abbr, dst_offset, start=None, end=None
):
self.dst_diff = dst_offset - std_offset
std_offset = _load_timedelta(std_offset)
self.std = _ttinfo(
utcoff=std_offset, dstoff=_load_timedelta(0), tzname=std_abbr
)
self.start = start
self.end = end
dst_offset = _load_timedelta(dst_offset)
delta = _load_timedelta(self.dst_diff)
self.dst = _ttinfo(utcoff=dst_offset, dstoff=delta, tzname=dst_abbr)
# These are assertions because the constructor should only be called
# by functions that would fail before passing start or end
assert start is not None, "No transition start specified"
assert end is not None, "No transition end specified"
self.get_trans_info = self._get_trans_info
self.get_trans_info_fromutc = self._get_trans_info_fromutc
def transitions(self, year):
start = self.start.year_to_epoch(year)
end = self.end.year_to_epoch(year)
return start, end
def _get_trans_info(self, ts, year, fold):
"""Get the information about the current transition - tti"""
start, end = self.transitions(year)
# With fold = 0, the period (denominated in local time) with the
# smaller offset starts at the end of the gap and ends at the end of
# the fold; with fold = 1, it runs from the start of the gap to the
# beginning of the fold.
#
# So in order to determine the DST boundaries we need to know both
# the fold and whether DST is positive or negative (rare), and it
# turns out that this boils down to fold XOR is_positive.
if fold == (self.dst_diff >= 0):
end -= self.dst_diff
else:
start += self.dst_diff
if start < end:
isdst = start <= ts < end
else:
isdst = not (end <= ts < start)
return self.dst if isdst else self.std
def _get_trans_info_fromutc(self, ts, year):
start, end = self.transitions(year)
start -= self.std.utcoff.total_seconds()
end -= self.dst.utcoff.total_seconds()
if start < end:
isdst = start <= ts < end
else:
isdst = not (end <= ts < start)
# For positive DST, the ambiguous period is one dst_diff after the end
# of DST; for negative DST, the ambiguous period is one dst_diff before
# the start of DST.
if self.dst_diff > 0:
ambig_start = end
ambig_end = end + self.dst_diff
else:
ambig_start = start
ambig_end = start - self.dst_diff
fold = ambig_start <= ts < ambig_end
return (self.dst if isdst else self.std, fold)
def _post_epoch_days_before_year(year):
"""Get the number of days between 1970-01-01 and YEAR-01-01"""
y = year - 1
return y * 365 + y // 4 - y // 100 + y // 400 - EPOCHORDINAL
class _DayOffset:
__slots__ = ["d", "julian", "hour", "minute", "second"]
def __init__(self, d, julian, hour=2, minute=0, second=0):
if not (0 + julian) <= d <= 365:
min_day = 0 + julian
raise ValueError(f"d must be in [{min_day}, 365], not: {d}")
self.d = d
self.julian = julian
self.hour = hour
self.minute = minute
self.second = second
def year_to_epoch(self, year):
days_before_year = _post_epoch_days_before_year(year)
d = self.d
if self.julian and d >= 59 and calendar.isleap(year):
d += 1
epoch = (days_before_year + d) * 86400
epoch += self.hour * 3600 + self.minute * 60 + self.second
return epoch
class _CalendarOffset:
__slots__ = ["m", "w", "d", "hour", "minute", "second"]
_DAYS_BEFORE_MONTH = (
-1,
0,
31,
59,
90,
120,
151,
181,
212,
243,
273,
304,
334,
)
def __init__(self, m, w, d, hour=2, minute=0, second=0):
if not 0 < m <= 12:
raise ValueError("m must be in (0, 12]")
if not 0 < w <= 5:
raise ValueError("w must be in (0, 5]")
if not 0 <= d <= 6:
raise ValueError("d must be in [0, 6]")
self.m = m
self.w = w
self.d = d
self.hour = hour
self.minute = minute
self.second = second
@classmethod
def _ymd2ord(cls, year, month, day):
return (
_post_epoch_days_before_year(year)
+ cls._DAYS_BEFORE_MONTH[month]
+ (month > 2 and calendar.isleap(year))
+ day
)
# TODO: These are not actually epoch dates as they are expressed in local time
def year_to_epoch(self, year):
"""Calculates the datetime of the occurrence from the year"""
# We know year and month, we need to convert w, d into day of month
#
# Week 1 is the first week in which day `d` (where 0 = Sunday) appears.
# Week 5 represents the last occurrence of day `d`, so we need to know
# the range of the month.
first_day, days_in_month = calendar.monthrange(year, self.m)
# This equation seems magical, so I'll break it down:
# 1. calendar says 0 = Monday, POSIX says 0 = Sunday
# so we need first_day + 1 to get 1 = Monday -> 7 = Sunday,
# which is still equivalent because this math is mod 7
# 2. Get first day - desired day mod 7: -1 % 7 = 6, so we don't need
# to do anything to adjust negative numbers.
# 3. Add 1 because month days are a 1-based index.
month_day = (self.d - (first_day + 1)) % 7 + 1
# Now use a 0-based index version of `w` to calculate the w-th
# occurrence of `d`
month_day += (self.w - 1) * 7
# month_day will only be > days_in_month if w was 5, and `w` means
# "last occurrence of `d`", so now we just check if we over-shot the
# end of the month and if so knock off 1 week.
if month_day > days_in_month:
month_day -= 7
ordinal = self._ymd2ord(year, self.m, month_day)
epoch = ordinal * 86400
epoch += self.hour * 3600 + self.minute * 60 + self.second
return epoch
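# Worked example of the equation above (dates chosen for illustration):
# for the 2nd Sunday of March 2021 (m=3, w=2, d=0),
# calendar.monthrange(2021, 3) == (0, 31), so
# month_day = (0 - (0 + 1)) % 7 + 1 = 7  -> March 7 is the first Sunday,
# month_day += (2 - 1) * 7 -> 14         -> March 14 is the target day.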
def _parse_tz_str(tz_str):
# The tz string has the format:
#
# std[offset[dst[offset],start[/time],end[/time]]]
#
# std and dst must be 3 or more characters long and must not contain
# a leading colon, embedded digits, commas, or plus or minus signs;
# The spaces between "std" and "offset" are only for display and are
# not actually present in the string.
#
# The format of the offset is ``[+|-]hh[:mm[:ss]]``
offset_str, *start_end_str = tz_str.split(",", 1)
# fmt: off
parser_re = re.compile(
r"(?P<std>[^<0-9:.+-]+|<[a-zA-Z0-9+\-]+>)" +
r"((?P<stdoff>[+-]?\d{1,2}(:\d{2}(:\d{2})?)?)" +
r"((?P<dst>[^0-9:.+-]+|<[a-zA-Z0-9+\-]+>)" +
r"((?P<dstoff>[+-]?\d{1,2}(:\d{2}(:\d{2})?)?))?" +
r")?" + # dst
r")?$" # stdoff
)
# fmt: on
m = parser_re.match(offset_str)
if m is None:
raise ValueError(f"{tz_str} is not a valid TZ string")
std_abbr = m.group("std")
dst_abbr = m.group("dst")
dst_offset = None
std_abbr = std_abbr.strip("<>")
if dst_abbr:
dst_abbr = dst_abbr.strip("<>")
std_offset = m.group("stdoff")
if std_offset:
try:
std_offset = _parse_tz_delta(std_offset)
except ValueError as e:
raise ValueError(f"Invalid STD offset in {tz_str}") from e
else:
std_offset = 0
if dst_abbr is not None:
dst_offset = m.group("dstoff")
if dst_offset:
try:
dst_offset = _parse_tz_delta(dst_offset)
except ValueError as e:
raise ValueError(f"Invalid DST offset in {tz_str}") from e
else:
dst_offset = std_offset + 3600
if not start_end_str:
raise ValueError(f"Missing transition rules: {tz_str}")
start_end_strs = start_end_str[0].split(",", 1)
try:
start, end = (_parse_dst_start_end(x) for x in start_end_strs)
except ValueError as e:
raise ValueError(f"Invalid TZ string: {tz_str}") from e
return _TZStr(std_abbr, std_offset, dst_abbr, dst_offset, start, end)
elif start_end_str:
raise ValueError(f"Transition rule present without DST: {tz_str}")
else:
# This is a static ttinfo, don't return _TZStr
return _ttinfo(
_load_timedelta(std_offset), _load_timedelta(0), std_abbr
)
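# Sketch: a rule-bearing TZ string yields a _TZStr, a bare offset a _ttinfo:
#
#     _parse_tz_str("EST5EDT,M3.2.0,M11.1.0")  # -> _TZStr with DST rules
#     _parse_tz_str("UTC0")                    # -> fixed-offset _ttinfo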
def _parse_dst_start_end(dststr):
date, *time = dststr.split("/")
if date[0] == "M":
n_is_julian = False
m = re.match(r"M(\d{1,2})\.(\d).(\d)$", date)
if m is None:
raise ValueError(f"Invalid dst start/end date: {dststr}")
date_offset = tuple(map(int, m.groups()))
offset = _CalendarOffset(*date_offset)
else:
if date[0] == "J":
n_is_julian = True
date = date[1:]
else:
n_is_julian = False
doy = int(date)
offset = _DayOffset(doy, n_is_julian)
if time:
time_components = list(map(int, time[0].split(":")))
n_components = len(time_components)
if n_components < 3:
time_components.extend([0] * (3 - n_components))
offset.hour, offset.minute, offset.second = time_components
return offset
def _parse_tz_delta(tz_delta):
match = re.match(
r"(?P<sign>[+-])?(?P<h>\d{1,2})(:(?P<m>\d{2})(:(?P<s>\d{2}))?)?",
tz_delta,
)
# Anything passed to this function should already have hit an equivalent
# regular expression to find the section to parse.
assert match is not None, tz_delta
h, m, s = (
int(v) if v is not None else 0
for v in map(match.group, ("h", "m", "s"))
)
total = h * 3600 + m * 60 + s
if not -86400 < total < 86400:
raise ValueError(
"Offset must be strictly between -24h and +24h:" + tz_delta
)
# Yes, +5 maps to an offset of -5h
if match.group("sign") != "-":
total *= -1
return total
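# Sketch: POSIX offsets are inverted relative to the ISO convention, so
#
#     _parse_tz_delta("5")      # -18000 seconds (UTC-5, as in "EST5")
#     _parse_tz_delta("-1:30")  # 5400 seconds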

View file

@ -1 +0,0 @@
__version__ = '1.10.0'

View file

@ -1,169 +0,0 @@
from __future__ import absolute_import
import sys
# True if we are running on Python 2.
PY2 = sys.version_info[0] == 2
PYVER = sys.version_info[:2]
JYTHON = sys.platform.startswith('java')
if PY2 and not JYTHON: # pragma: no cover
import cPickle as pickle
else: # pragma: no cover
import pickle
if not PY2: # pragma: no cover
xrange_ = range
NoneType = type(None)
string_type = str
unicode_text = str
byte_string = bytes
from urllib.parse import urlencode as url_encode
from urllib.parse import quote as url_quote
from urllib.parse import unquote as url_unquote
from urllib.parse import urlparse as url_parse
from urllib.request import url2pathname
import http.cookies as http_cookies
from base64 import b64decode as _b64decode, b64encode as _b64encode
try:
import dbm as anydbm
except:
import dumbdbm as anydbm
def b64decode(b):
return _b64decode(b.encode('ascii'))
def b64encode(s):
return _b64encode(s).decode('ascii')
def u_(s):
return str(s)
def bytes_(s):
if isinstance(s, byte_string):
return s
return str(s).encode('ascii', 'strict')
def dictkeyslist(d):
return list(d.keys())
else:
xrange_ = xrange
from types import NoneType
string_type = basestring
unicode_text = unicode
byte_string = str
from urllib import urlencode as url_encode
from urllib import quote as url_quote
from urllib import unquote as url_unquote
from urlparse import urlparse as url_parse
from urllib import url2pathname
import Cookie as http_cookies
from base64 import b64decode, b64encode
import anydbm
def u_(s):
if isinstance(s, unicode_text):
return s
if not isinstance(s, byte_string):
s = str(s)
return unicode(s, 'utf-8')
def bytes_(s):
if isinstance(s, byte_string):
return s
return str(s)
def dictkeyslist(d):
return d.keys()
def im_func(f):
if not PY2: # pragma: no cover
return getattr(f, '__func__', None)
else:
return getattr(f, 'im_func', None)
def default_im_func(f):
if not PY2: # pragma: no cover
return getattr(f, '__func__', f)
else:
return getattr(f, 'im_func', f)
def im_self(f):
if not PY2: # pragma: no cover
return getattr(f, '__self__', None)
else:
return getattr(f, 'im_self', None)
def im_class(f):
if not PY2: # pragma: no cover
self = im_self(f)
if self is not None:
return self.__class__
else:
return None
else:
return getattr(f, 'im_class', None)
def add_metaclass(metaclass):
"""Class decorator for creating a class with a metaclass."""
def wrapper(cls):
orig_vars = cls.__dict__.copy()
slots = orig_vars.get('__slots__')
if slots is not None:
if isinstance(slots, str):
slots = [slots]
for slots_var in slots:
orig_vars.pop(slots_var)
orig_vars.pop('__dict__', None)
orig_vars.pop('__weakref__', None)
return metaclass(cls.__name__, cls.__bases__, orig_vars)
return wrapper
if not PY2: # pragma: no cover
import builtins
exec_ = getattr(builtins, "exec")
def reraise(tp, value, tb=None):
if value.__traceback__ is not tb:
raise value.with_traceback(tb)
raise value
else: # pragma: no cover
def exec_(code, globs=None, locs=None):
"""Execute code in a namespace."""
if globs is None:
frame = sys._getframe(1)
globs = frame.f_globals
if locs is None:
locs = frame.f_locals
del frame
elif locs is None:
locs = globs
exec("""exec code in globs, locs""")
exec_("""def reraise(tp, value, tb=None):
raise tp, value, tb
""")
try:
from inspect import signature as func_signature
except ImportError:
from funcsigs import signature as func_signature
def bindfuncargs(arginfo, args, kwargs):
boundargs = arginfo.bind(*args, **kwargs)
return boundargs.args, boundargs.kwargs

View file

@ -1,615 +0,0 @@
"""This package contains the "front end" classes and functions
for Beaker caching.
Included are the :class:`.Cache` and :class:`.CacheManager` classes,
as well as the function decorators :func:`.region_decorate`,
:func:`.region_invalidate`.
"""
import warnings
from itertools import chain
from beaker._compat import u_, unicode_text, func_signature, bindfuncargs
import beaker.container as container
import beaker.util as util
from beaker.crypto.util import sha1
from beaker.exceptions import BeakerException, InvalidCacheBackendError
from beaker.synchronization import _threading
import beaker.ext.memcached as memcached
import beaker.ext.database as database
import beaker.ext.sqla as sqla
import beaker.ext.google as google
import beaker.ext.mongodb as mongodb
import beaker.ext.redisnm as redisnm
from functools import wraps
# Initialize the cache region dict
cache_regions = {}
"""Dictionary of 'region' arguments.
A "region" is a string name that refers to a series of cache
configuration arguments. An application may have multiple
"regions" - one which stores things in a memory cache, one
which writes data to files, etc.
The dictionary stores string key names mapped to dictionaries
of configuration arguments. Example::
from beaker.cache import cache_regions
cache_regions.update({
'short_term':{
'expire':60,
'type':'memory'
},
'long_term':{
'expire':1800,
'type':'dbm',
'data_dir':'/tmp',
}
})
"""
cache_managers = {}
class _backends(object):
initialized = False
def __init__(self, clsmap):
self._clsmap = clsmap
self._mutex = _threading.Lock()
def __getitem__(self, key):
try:
return self._clsmap[key]
except KeyError as e:
if not self.initialized:
self._mutex.acquire()
try:
if not self.initialized:
self._init()
self.initialized = True
return self._clsmap[key]
finally:
self._mutex.release()
raise e
def _init(self):
try:
import pkg_resources
# Load up the additional entry point defined backends
for entry_point in pkg_resources.iter_entry_points('beaker.backends'):
try:
namespace_manager = entry_point.load()
name = entry_point.name
if name in self._clsmap:
raise BeakerException("NamespaceManager name conflict,'%s' "
"already loaded" % name)
self._clsmap[name] = namespace_manager
except (InvalidCacheBackendError, SyntaxError):
# Ignore invalid backends
pass
except:
import sys
from pkg_resources import DistributionNotFound
# Warn when there's a problem loading a NamespaceManager
if not isinstance(sys.exc_info()[1], DistributionNotFound):
import traceback
try:
from StringIO import StringIO # Python2
except ImportError:
from io import StringIO # Python3
tb = StringIO()
traceback.print_exc(file=tb)
warnings.warn(
"Unable to load NamespaceManager "
"entry point: '%s': %s" % (
entry_point,
tb.getvalue()),
RuntimeWarning, 2)
except ImportError:
pass
# Initialize the basic available backends
clsmap = _backends({
'memory': container.MemoryNamespaceManager,
'dbm': container.DBMNamespaceManager,
'file': container.FileNamespaceManager,
'ext:memcached': memcached.MemcachedNamespaceManager,
'ext:database': database.DatabaseNamespaceManager,
'ext:sqla': sqla.SqlaNamespaceManager,
'ext:google': google.GoogleNamespaceManager,
'ext:mongodb': mongodb.MongoNamespaceManager,
'ext:redis': redisnm.RedisNamespaceManager
})
def cache_region(region, *args):
"""Decorate a function such that its return result is cached,
using a "region" to indicate the cache arguments.
Example::
from beaker.cache import cache_regions, cache_region
# configure regions
cache_regions.update({
'short_term':{
'expire':60,
'type':'memory'
}
})
@cache_region('short_term', 'load_things')
def load(search_term, limit, offset):
'''Load from a database given a search term, limit, offset.'''
return database.query(search_term)[offset:offset + limit]
The decorator can also be used with object methods. The ``self``
argument is not part of the cache key. This is based on the
actual string name ``self`` being in the first argument
position (new in 1.6)::
class MyThing(object):
@cache_region('short_term', 'load_things')
def load(self, search_term, limit, offset):
'''Load from a database given a search term, limit, offset.'''
return database.query(search_term)[offset:offset + limit]
Classmethods work as well - use ``cls`` as the name of the class argument,
and place the decorator around the function underneath ``@classmethod``
(new in 1.6)::
class MyThing(object):
@classmethod
@cache_region('short_term', 'load_things')
def load(cls, search_term, limit, offset):
'''Load from a database given a search term, limit, offset.'''
return database.query(search_term)[offset:offset + limit]
:param region: String name of the region corresponding to the desired
caching arguments, established in :attr:`.cache_regions`.
:param \*args: Optional ``str()``-compatible arguments which will uniquely
identify the key used by this decorated function, in addition
to the positional arguments passed to the function itself at call time.
This is recommended as it is needed to distinguish between any two functions
or methods that have the same name (regardless of parent class or not).
.. note::
The function being decorated must only be called with
positional arguments, and the arguments must support
being stringified with ``str()``. The concatenation
of the ``str()`` version of each argument, combined
with that of the ``*args`` sent to the decorator,
forms the unique cache key.
.. note::
When a method on a class is decorated, the ``self`` or ``cls``
argument in the first position is
not included in the "key" used for caching. New in 1.6.
"""
return _cache_decorate(args, None, None, region)
def region_invalidate(namespace, region, *args):
"""Invalidate a cache region corresponding to a function
decorated with :func:`.cache_region`.
:param namespace: The namespace of the cache to invalidate. This is typically
a reference to the original function (as returned by the :func:`.cache_region`
decorator), where the :func:`.cache_region` decorator applies a "memo" to
the function in order to locate the string name of the namespace.
:param region: String name of the region used with the decorator. This can be
``None`` in the usual case that the decorated function itself is passed,
not the string name of the namespace.
:param args: Stringifyable arguments that are used to locate the correct
key. This consists of the ``*args`` sent to the :func:`.cache_region`
decorator itself, plus the ``*args`` sent to the function itself
at runtime.
Example::
from beaker.cache import cache_regions, cache_region, region_invalidate
# configure regions
cache_regions.update({
'short_term':{
'expire':60,
'type':'memory'
}
})
@cache_region('short_term', 'load_data')
def load(search_term, limit, offset):
'''Load from a database given a search term, limit, offset.'''
return database.query(search_term)[offset:offset + limit]
def invalidate_search(search_term, limit, offset):
'''Invalidate the cached storage for a given search term, limit, offset.'''
region_invalidate(load, 'short_term', 'load_data', search_term, limit, offset)
Note that when a method on a class is decorated, the first argument ``cls``
or ``self`` is not included in the cache key. This means you don't send
it to :func:`.region_invalidate`::
class MyThing(object):
@cache_region('short_term', 'some_data')
def load(self, search_term, limit, offset):
'''Load from a database given a search term, limit, offset.'''
return database.query(search_term)[offset:offset + limit]
def invalidate_search(self, search_term, limit, offset):
'''Invalidate the cached storage for a given search term, limit, offset.'''
region_invalidate(self.load, 'short_term', 'some_data', search_term, limit, offset)
"""
if callable(namespace):
if not region:
region = namespace._arg_region
namespace = namespace._arg_namespace
if not region:
raise BeakerException("Region or callable function "
"namespace is required")
else:
region = cache_regions[region]
cache = Cache._get_cache(namespace, region)
_cache_decorator_invalidate(cache,
region.get('key_length', util.DEFAULT_CACHE_KEY_LENGTH),
args)
class Cache(object):
"""Front-end to the containment API implementing a data cache.
:param namespace: the namespace of this Cache
:param type: type of cache to use
:param expire: seconds to keep cached data
:param expiretime: seconds to keep cached data (legacy support)
:param starttime: time when the cache was created
"""
def __init__(self, namespace, type='memory', expiretime=None,
starttime=None, expire=None, **nsargs):
try:
cls = clsmap[type]
if isinstance(cls, InvalidCacheBackendError):
raise cls
except KeyError:
raise TypeError("Unknown cache implementation %r" % type)
if expire is not None:
expire = int(expire)
self.namespace_name = namespace
self.namespace = cls(namespace, **nsargs)
self.expiretime = expiretime or expire
self.starttime = starttime
self.nsargs = nsargs
@classmethod
def _get_cache(cls, namespace, kw):
key = namespace + str(kw)
try:
return cache_managers[key]
except KeyError:
cache_managers[key] = cache = cls(namespace, **kw)
return cache
def put(self, key, value, **kw):
self._get_value(key, **kw).set_value(value)
set_value = put
def get(self, key, **kw):
"""Retrieve a cached value from the container"""
return self._get_value(key, **kw).get_value()
get_value = get
def remove_value(self, key, **kw):
mycontainer = self._get_value(key, **kw)
mycontainer.clear_value()
remove = remove_value
def _get_value(self, key, **kw):
if isinstance(key, unicode_text):
key = key.encode('ascii', 'backslashreplace')
if 'type' in kw:
return self._legacy_get_value(key, **kw)
kw.setdefault('expiretime', self.expiretime)
kw.setdefault('starttime', self.starttime)
return container.Value(key, self.namespace, **kw)
@util.deprecated("Specifying a "
"'type' and other namespace configuration with cache.get()/put()/etc. "
"is deprecated. Specify 'type' and other namespace configuration to "
"cache_manager.get_cache() and/or the Cache constructor instead.")
def _legacy_get_value(self, key, type, **kw):
expiretime = kw.pop('expiretime', self.expiretime)
starttime = kw.pop('starttime', None)
createfunc = kw.pop('createfunc', None)
kwargs = self.nsargs.copy()
kwargs.update(kw)
c = Cache(self.namespace.namespace, type=type, **kwargs)
return c._get_value(key, expiretime=expiretime, createfunc=createfunc,
starttime=starttime)
def clear(self):
"""Clear all the values from the namespace"""
self.namespace.remove()
# dict interface
def __getitem__(self, key):
return self.get(key)
def __contains__(self, key):
return self._get_value(key).has_current_value()
def has_key(self, key):
return key in self
def __delitem__(self, key):
self.remove_value(key)
def __setitem__(self, key, value):
self.put(key, value)
class CacheManager(object):
def __init__(self, **kwargs):
"""Initialize a CacheManager object with a set of options
Options should be parsed with the
:func:`~beaker.util.parse_cache_config_options` function to
ensure only valid options are used.
"""
self.kwargs = kwargs
self.regions = kwargs.pop('cache_regions', {})
# Add these regions to the module global
cache_regions.update(self.regions)
def get_cache(self, name, **kwargs):
kw = self.kwargs.copy()
kw.update(kwargs)
return Cache._get_cache(name, kw)
def get_cache_region(self, name, region):
if region not in self.regions:
raise BeakerException('Cache region not configured: %s' % region)
kw = self.regions[region]
return Cache._get_cache(name, kw)
def region(self, region, *args):
"""Decorate a function to cache itself using a cache region
The region decorator requires arguments if there are more than
two functions with the same name in the same module. This is
because the namespace used for the function's cache is based on
the function's name and the module.
Example::
# Assuming a cache object is available like:
cache = CacheManager(dict_of_config_options)
def populate_things():
@cache.region('short_term', 'some_data')
def load(search_term, limit, offset):
return load_the_data(search_term, limit, offset)
return load('rabbits', 20, 0)
.. note::
The function being decorated must only be called with
positional arguments.
"""
return cache_region(region, *args)
def region_invalidate(self, namespace, region, *args):
"""Invalidate a cache region namespace or decorated function
This function only invalidates cache spaces created with the
cache_region decorator.
:param namespace: Either the namespace of the result to invalidate, or the
cached function
:param region: The region the function was cached to. If the function was
cached to a single region then this argument can be None
:param args: Arguments that were used to differentiate the cached
function as well as the arguments passed to the decorated
function
Example::
# Assuming a cache object is available like:
cache = CacheManager(dict_of_config_options)
def populate_things(invalidate=False):
@cache.region('short_term', 'some_data')
def load(search_term, limit, offset):
return load_the_data(search_term, limit, offset)
# If the results should be invalidated first
if invalidate:
cache.region_invalidate(load, None, 'some_data',
'rabbits', 20, 0)
return load('rabbits', 20, 0)
"""
return region_invalidate(namespace, region, *args)
def cache(self, *args, **kwargs):
"""Decorate a function to cache itself with supplied parameters
:param args: Used to make the key unique for this function, as in region()
above.
:param kwargs: Parameters to be passed to get_cache(), will override defaults
Example::
# Assuming a cache object is available like:
cache = CacheManager(dict_of_config_options)
def populate_things():
@cache.cache('mycache', expire=15)
def load(search_term, limit, offset):
return load_the_data(search_term, limit, offset)
return load('rabbits', 20, 0)
.. note::
The function being decorated must only be called with
positional arguments.
"""
return _cache_decorate(args, self, kwargs, None)
def invalidate(self, func, *args, **kwargs):
"""Invalidate a cache decorated function
This function only invalidates cache spaces created with the
cache decorator.
:param func: Decorated function to invalidate
:param args: Used to make the key unique for this function, as in region()
above.
:param kwargs: Parameters that were passed for use by get_cache(), note that
this is only required if a ``type`` was specified for the
function
Example::
# Assuming a cache object is available like:
cache = CacheManager(dict_of_config_options)
def populate_things(invalidate=False):
@cache.cache('mycache', type="file", expire=15)
def load(search_term, limit, offset):
return load_the_data(search_term, limit, offset)
# If the results should be invalidated first
if invalidate:
cache.invalidate(load, 'mycache', 'rabbits', 20, 0, type="file")
return load('rabbits', 20, 0)
"""
namespace = func._arg_namespace
cache = self.get_cache(namespace, **kwargs)
if hasattr(func, '_arg_region'):
cachereg = cache_regions[func._arg_region]
key_length = cachereg.get('key_length', util.DEFAULT_CACHE_KEY_LENGTH)
else:
key_length = kwargs.pop('key_length', util.DEFAULT_CACHE_KEY_LENGTH)
_cache_decorator_invalidate(cache, key_length, args)
def _cache_decorate(deco_args, manager, options, region):
"""Return a caching function decorator."""
cache = [None]
def decorate(func):
namespace = util.func_namespace(func)
skip_self = util.has_self_arg(func)
signature = func_signature(func)
@wraps(func)
def cached(*args, **kwargs):
if not cache[0]:
if region is not None:
if region not in cache_regions:
raise BeakerException(
'Cache region not configured: %s' % region)
reg = cache_regions[region]
if not reg.get('enabled', True):
return func(*args, **kwargs)
cache[0] = Cache._get_cache(namespace, reg)
elif manager:
cache[0] = manager.get_cache(namespace, **options)
else:
raise Exception("'manager + kwargs' or 'region' "
"argument is required")
cache_key_kwargs = []
if kwargs:
# kwargs provided, merge them into positional args
# to avoid having different cache keys.
args, kwargs = bindfuncargs(signature, args, kwargs)
cache_key_kwargs = [u_(':').join((u_(key), u_(value))) for key, value in kwargs.items()]
cache_key_args = args
if skip_self:
cache_key_args = args[1:]
cache_key = u_(" ").join(map(u_, chain(deco_args, cache_key_args, cache_key_kwargs)))
if region:
cachereg = cache_regions[region]
key_length = cachereg.get('key_length', util.DEFAULT_CACHE_KEY_LENGTH)
else:
key_length = options.pop('key_length', util.DEFAULT_CACHE_KEY_LENGTH)
# TODO: This is probably a bug as length is checked before converting to UTF8
# which will cause cache_key to grow in size.
if len(cache_key) + len(namespace) > int(key_length):
cache_key = sha1(cache_key.encode('utf-8')).hexdigest()
def go():
return func(*args, **kwargs)
# save org function name
go.__name__ = '_cached_%s' % (func.__name__,)
return cache[0].get_value(cache_key, createfunc=go)
cached._arg_namespace = namespace
if region is not None:
cached._arg_region = region
return cached
return decorate
def _cache_decorator_invalidate(cache, key_length, args):
"""Invalidate a cache key based on function arguments."""
cache_key = u_(" ").join(map(u_, args))
if len(cache_key) + len(cache.namespace_name) > key_length:
cache_key = sha1(cache_key.encode('utf-8')).hexdigest()
cache.remove_value(cache_key)

View file

@ -1,760 +0,0 @@
"""Container and Namespace classes"""
import errno
from ._compat import pickle, anydbm, add_metaclass, PYVER, unicode_text
import beaker.util as util
import logging
import os
import time
from beaker.exceptions import CreationAbortedError, MissingCacheParameter
from beaker.synchronization import _threading, file_synchronizer, \
mutex_synchronizer, NameLock, null_synchronizer
__all__ = ['Value', 'Container', 'ContainerContext',
'MemoryContainer', 'DBMContainer', 'NamespaceManager',
'MemoryNamespaceManager', 'DBMNamespaceManager', 'FileContainer',
'OpenResourceNamespaceManager',
'FileNamespaceManager', 'CreationAbortedError']
logger = logging.getLogger('beaker.container')
if logger.isEnabledFor(logging.DEBUG):
debug = logger.debug
else:
def debug(message, *args):
pass
class NamespaceManager(object):
"""Handles dictionary operations and locking for a namespace of
values.
:class:`.NamespaceManager` provides a dictionary-like interface,
implementing ``__getitem__()``, ``__setitem__()``, and
``__contains__()``, as well as functions related to lock
acquisition.
The implementation for setting and retrieving the namespace data is
handled by subclasses.
NamespaceManager may be used alone, or may be accessed by
one or more :class:`.Value` objects. :class:`.Value` objects provide per-key
services like expiration times and automatic recreation of values.
Multiple NamespaceManagers created with a particular name will all
share access to the same underlying datasource and will attempt to
synchronize against a common mutex object. The scope of this
sharing may be within a single process or across multiple
processes, depending on the type of NamespaceManager used.
The NamespaceManager itself is generally threadsafe, except in the
case of the DBMNamespaceManager in conjunction with the gdbm dbm
implementation.
"""
@classmethod
def _init_dependencies(cls):
"""Initialize module-level dependent libraries required
by this :class:`.NamespaceManager`."""
def __init__(self, namespace):
self._init_dependencies()
self.namespace = namespace
def get_creation_lock(self, key):
"""Return a locking object that is used to synchronize
multiple threads or processes which wish to generate a new
cache value.
This function is typically an instance of
:class:`.FileSynchronizer`, :class:`.ConditionSynchronizer`,
or :class:`.null_synchronizer`.
The creation lock is only used when a requested value
does not exist, or has been expired, and is only used
by the :class:`.Value` key-management object in conjunction
with a "createfunc" value-creation function.
"""
raise NotImplementedError()
def do_remove(self):
"""Implement removal of the entire contents of this
:class:`.NamespaceManager`.
e.g. for a file-based namespace, this would remove
all the files.
The front-end to this method is the
:meth:`.NamespaceManager.remove` method.
"""
raise NotImplementedError()
def acquire_read_lock(self):
"""Establish a read lock.
This operation is called before a key is read. By
default the function does nothing.
"""
def release_read_lock(self):
"""Release a read lock.
This operation is called after a key is read. By
default the function does nothing.
"""
def acquire_write_lock(self, wait=True, replace=False):
"""Establish a write lock.
This operation is called before a key is written.
A return value of ``True`` indicates the lock has
been acquired.
By default the function returns ``True`` unconditionally.
'replace' is a hint indicating the full contents
of the namespace may be safely discarded. Some backends
may implement this (i.e. file backend won't unpickle the
current contents).
"""
return True
def release_write_lock(self):
"""Release a write lock.
This operation is called after a new value is written.
By default this function does nothing.
"""
def has_key(self, key):
"""Return ``True`` if the given key is present in this
:class:`.Namespace`.
"""
return self.__contains__(key)
def __getitem__(self, key):
raise NotImplementedError()
def __setitem__(self, key, value):
raise NotImplementedError()
def set_value(self, key, value, expiretime=None):
"""Sets a value in this :class:`.NamespaceManager`.
This is the same as ``__setitem__()``, but
also allows an expiration time to be passed
at the same time.
"""
self[key] = value
def __contains__(self, key):
raise NotImplementedError()
def __delitem__(self, key):
raise NotImplementedError()
def keys(self):
"""Return the list of all keys.
This method may not be supported by all
:class:`.NamespaceManager` implementations.
"""
raise NotImplementedError()
def remove(self):
"""Remove the entire contents of this
:class:`.NamespaceManager`.
e.g. for a file-based namespace, this would remove
all the files.
"""
self.do_remove()
class OpenResourceNamespaceManager(NamespaceManager):
"""A NamespaceManager where read/write operations require opening/
closing of a resource which is possibly mutexed.
"""
def __init__(self, namespace):
NamespaceManager.__init__(self, namespace)
self.access_lock = self.get_access_lock()
self.openers = 0
self.mutex = _threading.Lock()
def get_access_lock(self):
raise NotImplementedError()
def do_open(self, flags, replace):
raise NotImplementedError()
def do_close(self):
raise NotImplementedError()
def acquire_read_lock(self):
self.access_lock.acquire_read_lock()
try:
self.open('r', checkcount=True)
except:
self.access_lock.release_read_lock()
raise
def release_read_lock(self):
try:
self.close(checkcount=True)
finally:
self.access_lock.release_read_lock()
def acquire_write_lock(self, wait=True, replace=False):
r = self.access_lock.acquire_write_lock(wait)
try:
if (wait or r):
self.open('c', checkcount=True, replace=replace)
return r
except:
self.access_lock.release_write_lock()
raise
def release_write_lock(self):
try:
self.close(checkcount=True)
finally:
self.access_lock.release_write_lock()
def open(self, flags, checkcount=False, replace=False):
self.mutex.acquire()
try:
if checkcount:
if self.openers == 0:
self.do_open(flags, replace)
self.openers += 1
else:
self.do_open(flags, replace)
self.openers = 1
finally:
self.mutex.release()
def close(self, checkcount=False):
self.mutex.acquire()
try:
if checkcount:
self.openers -= 1
if self.openers == 0:
self.do_close()
else:
if self.openers > 0:
self.do_close()
self.openers = 0
finally:
self.mutex.release()
def remove(self):
self.access_lock.acquire_write_lock()
try:
self.close(checkcount=False)
self.do_remove()
finally:
self.access_lock.release_write_lock()
class Value(object):
"""Implements synchronization, expiration, and value-creation logic
for a single value stored in a :class:`.NamespaceManager`.
"""
__slots__ = 'key', 'createfunc', 'expiretime', 'expire_argument', 'starttime', 'storedtime',\
'namespace'
def __init__(self, key, namespace, createfunc=None, expiretime=None, starttime=None):
self.key = key
self.createfunc = createfunc
self.expire_argument = expiretime
self.starttime = starttime
self.storedtime = -1
self.namespace = namespace
def has_value(self):
"""return true if the container has a value stored.
This is regardless of it being expired or not.
"""
self.namespace.acquire_read_lock()
try:
return self.key in self.namespace
finally:
self.namespace.release_read_lock()
def can_have_value(self):
return self.has_current_value() or self.createfunc is not None
def has_current_value(self):
self.namespace.acquire_read_lock()
try:
has_value = self.key in self.namespace
if has_value:
try:
stored, expired, value = self._get_value()
return not self._is_expired(stored, expired)
except KeyError:
pass
return False
finally:
self.namespace.release_read_lock()
def _is_expired(self, storedtime, expiretime):
"""Return true if this container's value is expired."""
return (
(
self.starttime is not None and
storedtime < self.starttime
)
or
(
expiretime is not None and
time.time() >= expiretime + storedtime
)
)
def get_value(self):
self.namespace.acquire_read_lock()
try:
has_value = self.has_value()
if has_value:
try:
stored, expired, value = self._get_value()
if not self._is_expired(stored, expired):
return value
except KeyError:
# guard against un-mutexed backends raising KeyError
has_value = False
if not self.createfunc:
raise KeyError(self.key)
finally:
self.namespace.release_read_lock()
has_createlock = False
creation_lock = self.namespace.get_creation_lock(self.key)
if has_value:
if not creation_lock.acquire(wait=False):
debug("get_value returning old value while new one is created")
return value
else:
debug("lock_creatfunc (didnt wait)")
has_createlock = True
if not has_createlock:
debug("lock_createfunc (waiting)")
creation_lock.acquire()
debug("lock_createfunc (waited)")
try:
# see if someone created the value already
self.namespace.acquire_read_lock()
try:
if self.has_value():
try:
stored, expired, value = self._get_value()
if not self._is_expired(stored, expired):
return value
except KeyError:
# guard against un-mutexed backends raising KeyError
pass
finally:
self.namespace.release_read_lock()
debug("get_value creating new value")
v = self.createfunc()
self.set_value(v)
return v
finally:
creation_lock.release()
debug("released create lock")
def _get_value(self):
value = self.namespace[self.key]
try:
stored, expired, value = value
except ValueError:
            if len(value) != 2:
raise
# Old format: upgrade
stored, value = value
expired = self.expire_argument
debug("get_value upgrading time %r expire time %r", stored, self.expire_argument)
self.namespace.release_read_lock()
self.set_value(value, stored)
self.namespace.acquire_read_lock()
except TypeError:
# occurs when the value is None. memcached
# may yank the rug from under us in which case
# that's the result
raise KeyError(self.key)
return stored, expired, value
def set_value(self, value, storedtime=None):
self.namespace.acquire_write_lock()
try:
if storedtime is None:
storedtime = time.time()
debug("set_value stored time %r expire time %r", storedtime, self.expire_argument)
self.namespace.set_value(self.key, (storedtime, self.expire_argument, value),
expiretime=self.expire_argument)
finally:
self.namespace.release_write_lock()
def clear_value(self):
self.namespace.acquire_write_lock()
try:
debug("clear_value")
if self.key in self.namespace:
try:
del self.namespace[self.key]
except KeyError:
# guard against un-mutexed backends raising KeyError
pass
self.storedtime = -1
finally:
self.namespace.release_write_lock()
class AbstractDictionaryNSManager(NamespaceManager):
"""A subclassable NamespaceManager that places data in a dictionary.
Subclasses should provide a "dictionary" attribute or descriptor
which returns a dict-like object. The dictionary will store keys
that are local to the "namespace" attribute of this manager, so
ensure that the dictionary will not be used by any other namespace.
e.g.::
import collections
cached_data = collections.defaultdict(dict)
class MyDictionaryManager(AbstractDictionaryNSManager):
def __init__(self, namespace):
AbstractDictionaryNSManager.__init__(self, namespace)
self.dictionary = cached_data[self.namespace]
The above stores data in a global dictionary called "cached_data",
which is structured as a dictionary of dictionaries, keyed
first on namespace name to a sub-dictionary, then on actual
cache key to value.
"""
def get_creation_lock(self, key):
return NameLock(
identifier="memorynamespace/funclock/%s/%s" %
(self.namespace, key),
reentrant=True
)
def __getitem__(self, key):
return self.dictionary[key]
def __contains__(self, key):
return self.dictionary.__contains__(key)
def has_key(self, key):
return self.dictionary.__contains__(key)
def __setitem__(self, key, value):
self.dictionary[key] = value
def __delitem__(self, key):
del self.dictionary[key]
def do_remove(self):
self.dictionary.clear()
def keys(self):
return self.dictionary.keys()
class MemoryNamespaceManager(AbstractDictionaryNSManager):
""":class:`.NamespaceManager` that uses a Python dictionary for storage."""
namespaces = util.SyncDict()
def __init__(self, namespace, **kwargs):
AbstractDictionaryNSManager.__init__(self, namespace)
self.dictionary = MemoryNamespaceManager.\
namespaces.get(self.namespace, dict)
class DBMNamespaceManager(OpenResourceNamespaceManager):
""":class:`.NamespaceManager` that uses ``dbm`` files for storage."""
def __init__(self, namespace, dbmmodule=None, data_dir=None,
dbm_dir=None, lock_dir=None,
digest_filenames=True, **kwargs):
self.digest_filenames = digest_filenames
if not dbm_dir and not data_dir:
raise MissingCacheParameter("data_dir or dbm_dir is required")
elif dbm_dir:
self.dbm_dir = dbm_dir
else:
self.dbm_dir = data_dir + "/container_dbm"
util.verify_directory(self.dbm_dir)
if not lock_dir and not data_dir:
raise MissingCacheParameter("data_dir or lock_dir is required")
elif lock_dir:
self.lock_dir = lock_dir
else:
self.lock_dir = data_dir + "/container_dbm_lock"
util.verify_directory(self.lock_dir)
self.dbmmodule = dbmmodule or anydbm
self.dbm = None
OpenResourceNamespaceManager.__init__(self, namespace)
self.file = util.encoded_path(root=self.dbm_dir,
identifiers=[self.namespace],
extension='.dbm',
digest_filenames=self.digest_filenames)
debug("data file %s", self.file)
self._checkfile()
def get_access_lock(self):
return file_synchronizer(identifier=self.namespace,
lock_dir=self.lock_dir)
def get_creation_lock(self, key):
return file_synchronizer(
identifier="dbmcontainer/funclock/%s/%s" % (
self.namespace, key
),
lock_dir=self.lock_dir
)
def file_exists(self, file):
if os.access(file, os.F_OK):
return True
else:
for ext in ('db', 'dat', 'pag', 'dir'):
if os.access(file + os.extsep + ext, os.F_OK):
return True
return False
def _ensuredir(self, filename):
dirname = os.path.dirname(filename)
if not os.path.exists(dirname):
util.verify_directory(dirname)
def _checkfile(self):
if not self.file_exists(self.file):
self._ensuredir(self.file)
g = self.dbmmodule.open(self.file, 'c')
g.close()
    def get_filenames(self):
        filenames = []
        if os.access(self.file, os.F_OK):
            filenames.append(self.file)
        for ext in ('pag', 'dir', 'db', 'dat'):
            if os.access(self.file + os.extsep + ext, os.F_OK):
                filenames.append(self.file + os.extsep + ext)
        return filenames
def do_open(self, flags, replace):
debug("opening dbm file %s", self.file)
try:
self.dbm = self.dbmmodule.open(self.file, flags)
        except Exception:
self._checkfile()
self.dbm = self.dbmmodule.open(self.file, flags)
def do_close(self):
if self.dbm is not None:
debug("closing dbm file %s", self.file)
self.dbm.close()
def do_remove(self):
for f in self.get_filenames():
os.remove(f)
def __getitem__(self, key):
return pickle.loads(self.dbm[key])
def __contains__(self, key):
if PYVER == (3, 2):
# Looks like this is a bug that got solved in PY3.3 and PY3.4
# http://bugs.python.org/issue19288
if isinstance(key, unicode_text):
key = key.encode('UTF-8')
return key in self.dbm
def __setitem__(self, key, value):
self.dbm[key] = pickle.dumps(value)
def __delitem__(self, key):
del self.dbm[key]
def keys(self):
return self.dbm.keys()
class FileNamespaceManager(OpenResourceNamespaceManager):
""":class:`.NamespaceManager` that uses binary files for storage.
Each namespace is implemented as a single file storing a
dictionary of key/value pairs, serialized using the Python
``pickle`` module.
"""
def __init__(self, namespace, data_dir=None, file_dir=None, lock_dir=None,
digest_filenames=True, **kwargs):
self.digest_filenames = digest_filenames
if not file_dir and not data_dir:
raise MissingCacheParameter("data_dir or file_dir is required")
elif file_dir:
self.file_dir = file_dir
else:
self.file_dir = data_dir + "/container_file"
util.verify_directory(self.file_dir)
if not lock_dir and not data_dir:
raise MissingCacheParameter("data_dir or lock_dir is required")
elif lock_dir:
self.lock_dir = lock_dir
else:
self.lock_dir = data_dir + "/container_file_lock"
util.verify_directory(self.lock_dir)
OpenResourceNamespaceManager.__init__(self, namespace)
self.file = util.encoded_path(root=self.file_dir,
identifiers=[self.namespace],
extension='.cache',
digest_filenames=self.digest_filenames)
self.hash = {}
debug("data file %s", self.file)
def get_access_lock(self):
return file_synchronizer(identifier=self.namespace,
lock_dir=self.lock_dir)
def get_creation_lock(self, key):
return file_synchronizer(
identifier="dbmcontainer/funclock/%s/%s" % (
self.namespace, key
),
lock_dir=self.lock_dir
)
def file_exists(self, file):
return os.access(file, os.F_OK)
def do_open(self, flags, replace):
if not replace and self.file_exists(self.file):
try:
with open(self.file, 'rb') as fh:
self.hash = pickle.load(fh)
except IOError as e:
# Ignore EACCES and ENOENT as it just means we are no longer
# able to access the file or that it no longer exists
if e.errno not in [errno.EACCES, errno.ENOENT]:
raise
self.flags = flags
def do_close(self):
if self.flags == 'c' or self.flags == 'w':
pickled = pickle.dumps(self.hash)
util.safe_write(self.file, pickled)
self.hash = {}
self.flags = None
def do_remove(self):
try:
os.remove(self.file)
except OSError:
# for instance, because we haven't yet used this cache,
# but client code has asked for a clear() operation...
pass
self.hash = {}
def __getitem__(self, key):
return self.hash[key]
def __contains__(self, key):
return key in self.hash
def __setitem__(self, key, value):
self.hash[key] = value
def __delitem__(self, key):
del self.hash[key]
def keys(self):
return self.hash.keys()
#### legacy stuff to support the old "Container" class interface
namespace_classes = {}
ContainerContext = dict
class ContainerMeta(type):
def __init__(cls, classname, bases, dict_):
namespace_classes[cls] = cls.namespace_class
return type.__init__(cls, classname, bases, dict_)
def __call__(self, key, context, namespace, createfunc=None,
expiretime=None, starttime=None, **kwargs):
if namespace in context:
ns = context[namespace]
else:
nscls = namespace_classes[self]
context[namespace] = ns = nscls(namespace, **kwargs)
return Value(key, ns, createfunc=createfunc,
expiretime=expiretime, starttime=starttime)
@add_metaclass(ContainerMeta)
class Container(object):
"""Implements synchronization and value-creation logic
for a 'value' stored in a :class:`.NamespaceManager`.
:class:`.Container` and its subclasses are deprecated. The
:class:`.Value` class is now used for this purpose.
"""
namespace_class = NamespaceManager
class FileContainer(Container):
namespace_class = FileNamespaceManager
class MemoryContainer(Container):
namespace_class = MemoryNamespaceManager
class DBMContainer(Container):
namespace_class = DBMNamespaceManager
DbmContainer = DBMContainer
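# Added usage sketch (not part of the original module): assuming it is
# importable as ``beaker.container``, a Value ties a creation function and an
# expiration time to one key of a dict-backed MemoryNamespaceManager.
from beaker.container import MemoryNamespaceManager, Value

ns = MemoryNamespaceManager('demo')
answer = Value('answer', ns, createfunc=lambda: 42, expiretime=300)
print(answer.get_value())   # runs createfunc under the creation lock -> 42
print(answer.get_value())   # now served straight from the namespace store
answer.clear_value()        # drops the stored value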

View file

@@ -1,29 +0,0 @@
from beaker._compat import string_type
# (c) 2005 Ian Bicking and contributors; written for Paste (http://pythonpaste.org)
# Licensed under the MIT license: http://www.opensource.org/licenses/mit-license.php
def asbool(obj):
if isinstance(obj, string_type):
obj = obj.strip().lower()
if obj in ['true', 'yes', 'on', 'y', 't', '1']:
return True
elif obj in ['false', 'no', 'off', 'n', 'f', '0']:
return False
else:
raise ValueError(
"String is not true/false: %r" % obj)
return bool(obj)
def aslist(obj, sep=None, strip=True):
if isinstance(obj, string_type):
lst = obj.split(sep)
if strip:
lst = [v.strip() for v in lst]
return lst
elif isinstance(obj, (list, tuple)):
return obj
elif obj is None:
return []
else:
return [obj]
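# Added illustration (not part of the original module): how asbool and aslist
# normalize typical configuration values.
assert asbool(' Yes ') is True
assert asbool('off') is False
assert asbool(1) is True
assert aslist('a, b, c', sep=',') == ['a', 'b', 'c']
assert aslist(None) == []
assert aslist(('x', 'y')) == ('x', 'y')   # tuples pass through unchanged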

View file

@@ -1,72 +0,0 @@
import sys
from ._compat import http_cookies
# Some versions of Python 2.7 and later won't need this encoding bug fix:
_cookie_encodes_correctly = http_cookies.SimpleCookie().value_encode(';') == (';', '"\\073"')
# Cookie pickling bug is fixed in Python 2.7.9 and Python 3.4.3+
# http://bugs.python.org/issue22775
cookie_pickles_properly = (
(sys.version_info[:2] == (2, 7) and sys.version_info >= (2, 7, 9)) or
sys.version_info >= (3, 4, 3)
)
# Add support for the SameSite attribute (obsolete when PY37 is unsupported).
http_cookies.Morsel._reserved.setdefault('samesite', 'SameSite')
# Adapted from django.http.cookies, with the bad_cookies behaviour always
# enabled to cope with any invalid cookie key while keeping the session
# around.
class SimpleCookie(http_cookies.SimpleCookie):
if not cookie_pickles_properly:
def __setitem__(self, key, value):
# Apply the fix from http://bugs.python.org/issue22775 where
# it's not fixed in Python itself
if isinstance(value, http_cookies.Morsel):
# allow assignment of constructed Morsels (e.g. for pickling)
dict.__setitem__(self, key, value)
else:
super(SimpleCookie, self).__setitem__(key, value)
if not _cookie_encodes_correctly:
def value_encode(self, val):
# Some browsers do not support quoted-string from RFC 2109,
# including some versions of Safari and Internet Explorer.
# These browsers split on ';', and some versions of Safari
# are known to split on ', '. Therefore, we encode ';' and ','
# SimpleCookie already does the hard work of encoding and decoding.
# It uses octal sequences like '\\012' for newline etc.
# and non-ASCII chars. We just make use of this mechanism, to
# avoid introducing two encoding schemes which would be confusing
# and especially awkward for javascript.
# NB, contrary to Python docs, value_encode returns a tuple containing
# (real val, encoded_val)
val, encoded = super(SimpleCookie, self).value_encode(val)
encoded = encoded.replace(";", "\\073").replace(",", "\\054")
# If encoded now contains any quoted chars, we need double quotes
# around the whole string.
if "\\" in encoded and not encoded.startswith('"'):
encoded = '"' + encoded + '"'
return val, encoded
def load(self, rawdata):
self.bad_cookies = set()
super(SimpleCookie, self).load(rawdata)
for key in self.bad_cookies:
del self[key]
# override private __set() method:
    # (needed for using our Morsel, and for laxness with CookieError)
def _BaseCookie__set(self, key, real_value, coded_value):
try:
super(SimpleCookie, self)._BaseCookie__set(key, real_value, coded_value)
except http_cookies.CookieError:
if not hasattr(self, 'bad_cookies'):
self.bad_cookies = set()
self.bad_cookies.add(key)
dict.__setitem__(self, key, http_cookies.Morsel())

View file

@@ -1,83 +0,0 @@
from .._compat import JYTHON
from beaker.crypto.pbkdf2 import pbkdf2
from beaker.crypto.util import hmac, sha1, hmac_sha1, md5
from beaker import util
from beaker.exceptions import InvalidCryptoBackendError
keyLength = None
DEFAULT_NONCE_BITS = 128
CRYPTO_MODULES = {}
def load_default_module():
""" Load the default crypto module
"""
if JYTHON:
try:
from beaker.crypto import jcecrypto
return jcecrypto
except ImportError:
pass
else:
try:
from beaker.crypto import nsscrypto
return nsscrypto
except ImportError:
try:
from beaker.crypto import pycrypto
return pycrypto
except ImportError:
pass
from beaker.crypto import noencryption
return noencryption
def register_crypto_module(name, mod):
"""
Register the given module under the name given.
"""
CRYPTO_MODULES[name] = mod
def get_crypto_module(name):
"""
Get the active crypto module for this name
"""
if name not in CRYPTO_MODULES:
if name == 'default':
register_crypto_module('default', load_default_module())
elif name == 'nss':
from beaker.crypto import nsscrypto
register_crypto_module(name, nsscrypto)
elif name == 'pycrypto':
from beaker.crypto import pycrypto
register_crypto_module(name, pycrypto)
elif name == 'cryptography':
from beaker.crypto import pyca_cryptography
register_crypto_module(name, pyca_cryptography)
else:
raise InvalidCryptoBackendError(
"No crypto backend with name '%s' is registered." % name)
return CRYPTO_MODULES[name]
def generateCryptoKeys(master_key, salt, iterations, keylen):
    # Derive the session encryption key from the master key and salt using
    # PBKDF2; the caller chooses the iteration count and key length.
    return pbkdf2(master_key, salt, iterations=iterations, dklen=keylen)
def get_nonce_size(number_of_bits):
if number_of_bits % 8:
raise ValueError('Nonce complexity currently supports multiples of 8')
bytes = number_of_bits // 8
b64bytes = ((4 * bytes // 3) + 3) & ~3
return bytes, b64bytes
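# Added worked example (not part of the original module): a 128-bit nonce is
# 16 raw bytes, and base64-encoding 16 bytes needs
# ((4 * 16 // 3) + 3) & ~3 == 24 characters once padded to a multiple of four.
assert get_nonce_size(DEFAULT_NONCE_BITS) == (16, 24)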

View file

@@ -1,41 +0,0 @@
"""
Encryption module that uses the Java Cryptography Extensions (JCE).
Note that in default installations of the Java Runtime Environment, the
maximum key length is limited to 128 bits due to US export
restrictions. This makes the generated keys incompatible with the ones
generated by pycryptopp, which has no such restrictions. To fix this,
download the "Unlimited Strength Jurisdiction Policy Files" from Sun,
which will allow encryption using 256 bit AES keys.
"""
from warnings import warn
from javax.crypto import Cipher
from javax.crypto.spec import SecretKeySpec, IvParameterSpec
import jarray
# Initialization vector filled with zeros
_iv = IvParameterSpec(jarray.zeros(16, 'b'))
def aesEncrypt(data, key):
cipher = Cipher.getInstance('AES/CTR/NoPadding')
skeySpec = SecretKeySpec(key, 'AES')
cipher.init(Cipher.ENCRYPT_MODE, skeySpec, _iv)
return cipher.doFinal(data).tostring()
# magic.
aesDecrypt = aesEncrypt
has_aes = True
def getKeyLength():
maxlen = Cipher.getMaxAllowedKeyLength('AES/CTR/NoPadding')
return min(maxlen, 256) / 8
if getKeyLength() < 32:
warn('Crypto implementation only supports key lengths up to %d bits. '
'Generated session cookies may be incompatible with other '
'environments' % (getKeyLength() * 8))

View file

@@ -1,12 +0,0 @@
"""Encryption module that does nothing"""
def aesEncrypt(data, key):
return data
def aesDecrypt(data, key):
return data
has_aes = False
def getKeyLength():
return 32

View file

@@ -1,47 +0,0 @@
"""Encryption module that uses nsscrypto"""
import nss.nss
nss.nss.nss_init_nodb()
# Apparently the rest of beaker doesn't care about the particular cipher,
# mode and padding used.
# NOTE: A constant IV!!! This is only secure if the KEY is never reused!!!
_mech = nss.nss.CKM_AES_CBC_PAD
_iv = '\0' * nss.nss.get_iv_length(_mech)
def aesEncrypt(data, key):
slot = nss.nss.get_best_slot(_mech)
key_obj = nss.nss.import_sym_key(slot, _mech, nss.nss.PK11_OriginGenerated,
nss.nss.CKA_ENCRYPT, nss.nss.SecItem(key))
param = nss.nss.param_from_iv(_mech, nss.nss.SecItem(_iv))
ctx = nss.nss.create_context_by_sym_key(_mech, nss.nss.CKA_ENCRYPT, key_obj,
param)
l1 = ctx.cipher_op(data)
# Yes, DIGEST. This needs fixing in NSS, but apparently nobody (including
# me :( ) cares enough.
l2 = ctx.digest_final()
return l1 + l2
def aesDecrypt(data, key):
slot = nss.nss.get_best_slot(_mech)
key_obj = nss.nss.import_sym_key(slot, _mech, nss.nss.PK11_OriginGenerated,
nss.nss.CKA_DECRYPT, nss.nss.SecItem(key))
param = nss.nss.param_from_iv(_mech, nss.nss.SecItem(_iv))
ctx = nss.nss.create_context_by_sym_key(_mech, nss.nss.CKA_DECRYPT, key_obj,
param)
l1 = ctx.cipher_op(data)
# Yes, DIGEST. This needs fixing in NSS, but apparently nobody (including
# me :( ) cares enough.
l2 = ctx.digest_final()
return l1 + l2
has_aes = True
def getKeyLength():
return 32

View file

@@ -1,94 +0,0 @@
"""
PBKDF2 Implementation adapted from django.utils.crypto.
This is used to generate the encryption key for enciphered sessions.
"""
from beaker._compat import bytes_, xrange_
import hmac
import struct
import hashlib
import binascii
def _bin_to_long(x):
"""Convert a binary string into a long integer"""
return int(binascii.hexlify(x), 16)
def _long_to_bin(x, hex_format_string):
"""
Convert a long integer into a binary string.
hex_format_string is like "%020x" for padding 10 characters.
"""
return binascii.unhexlify((hex_format_string % x).encode('ascii'))
if hasattr(hashlib, "pbkdf2_hmac"):
def pbkdf2(password, salt, iterations, dklen=0, digest=None):
"""
Implements PBKDF2 using the stdlib. This is used in Python 2.7.8+ and 3.4+.
        HMAC+SHA1 is used as the default pseudorandom function (see the
        ``digest`` default below).
        As of 2014, 100,000 iterations was the recommended default which took
        100ms on a 2.7 GHz Intel i7 with an optimized implementation. This is
        probably the bare minimum for security given 1000 iterations was
        recommended in 2001.
"""
if digest is None:
digest = hashlib.sha1
if not dklen:
dklen = None
password = bytes_(password)
salt = bytes_(salt)
return hashlib.pbkdf2_hmac(
digest().name, password, salt, iterations, dklen)
else:
def pbkdf2(password, salt, iterations, dklen=0, digest=None):
"""
Implements PBKDF2 as defined in RFC 2898, section 5.2
        HMAC+SHA1 is used as the default pseudorandom function (see the
        ``digest`` default below).
        As of 2014, 100,000 iterations was the recommended default which took
        100ms on a 2.7 GHz Intel i7 with an optimized implementation. This is
        probably the bare minimum for security given 1000 iterations was
        recommended in 2001. This code is very well optimized for CPython and
        is about five times slower than OpenSSL's implementation.
"""
assert iterations > 0
if not digest:
digest = hashlib.sha1
password = bytes_(password)
salt = bytes_(salt)
hlen = digest().digest_size
if not dklen:
dklen = hlen
if dklen > (2 ** 32 - 1) * hlen:
raise OverflowError('dklen too big')
        nblocks = -(-dklen // hlen)  # ceil(dklen / hlen)
        r = dklen - (nblocks - 1) * hlen
hex_format_string = "%%0%ix" % (hlen * 2)
inner, outer = digest(), digest()
if len(password) > inner.block_size:
password = digest(password).digest()
password += b'\x00' * (inner.block_size - len(password))
inner.update(password.translate(hmac.trans_36))
outer.update(password.translate(hmac.trans_5C))
def F(i):
u = salt + struct.pack(b'>I', i)
result = 0
for j in xrange_(int(iterations)):
dig1, dig2 = inner.copy(), outer.copy()
dig1.update(u)
dig2.update(dig1.digest())
u = dig2.digest()
result ^= _bin_to_long(u)
return _long_to_bin(result, hex_format_string)
        T = [F(x) for x in xrange_(1, nblocks)]
        return b''.join(T) + F(nblocks)[:r]
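# Added usage sketch (not part of the original module): both branches above
# expose the same signature, so this derives a 32-byte key either way.
derived = pbkdf2(b'master-secret', b'salt', iterations=1000, dklen=32)
assert len(derived) == 32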

View file

@@ -1,52 +0,0 @@
"""Encryption module that uses pyca/cryptography"""
import os
import json
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import (
Cipher, algorithms, modes
)
def aesEncrypt(data, key):
# Generate a random 96-bit IV.
iv = os.urandom(12)
# Construct an AES-GCM Cipher object with the given key and a
# randomly generated IV.
encryptor = Cipher(
algorithms.AES(key),
modes.GCM(iv),
backend=default_backend()
).encryptor()
# Encrypt the plaintext and get the associated ciphertext.
# GCM does not require padding.
ciphertext = encryptor.update(data) + encryptor.finalize()
return iv + encryptor.tag + ciphertext
def aesDecrypt(data, key):
iv = data[:12]
tag = data[12:28]
ciphertext = data[28:]
# Construct a Cipher object, with the key, iv, and additionally the
# GCM tag used for authenticating the message.
decryptor = Cipher(
algorithms.AES(key),
modes.GCM(iv, tag),
backend=default_backend()
).decryptor()
# Decryption gets us the authenticated plaintext.
# If the tag does not match an InvalidTag exception will be raised.
return decryptor.update(ciphertext) + decryptor.finalize()
has_aes = True
def getKeyLength():
return 32
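# Added round-trip sketch (not part of the original module): a random 256-bit
# key; the 12-byte IV and 16-byte GCM tag travel inside the blob, exactly as
# aesDecrypt() above expects.
_key = os.urandom(getKeyLength())
_blob = aesEncrypt(b'session payload', _key)
assert aesDecrypt(_blob, _key) == b'session payload'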

View file

@@ -1,34 +0,0 @@
"""Encryption module that uses pycryptopp or pycrypto"""
try:
# Pycryptopp is preferred over Crypto because Crypto has had
# various periods of not being maintained, and pycryptopp uses
# the Crypto++ library which is generally considered the 'gold standard'
# of crypto implementations
from pycryptopp.cipher import aes
def aesEncrypt(data, key):
cipher = aes.AES(key)
return cipher.process(data)
# magic.
aesDecrypt = aesEncrypt
except ImportError:
from Crypto.Cipher import AES
from Crypto.Util import Counter
def aesEncrypt(data, key):
cipher = AES.new(key, AES.MODE_CTR,
counter=Counter.new(128, initial_value=0))
return cipher.encrypt(data)
def aesDecrypt(data, key):
cipher = AES.new(key, AES.MODE_CTR,
counter=Counter.new(128, initial_value=0))
return cipher.decrypt(data)
has_aes = True
def getKeyLength():
return 32

View file

@@ -1,16 +0,0 @@
from hashlib import md5
try:
# Use PyCrypto (if available)
from Crypto.Hash import HMAC as hmac, SHA as hmac_sha1
sha1 = hmac_sha1.new
except ImportError:
# PyCrypto not available. Use the Python standard library.
import hmac
# NOTE: We have to use the callable with hashlib (hashlib.sha1),
# otherwise hmac only accepts the sha module object itself
from hashlib import sha1
hmac_sha1 = sha1

View file

@@ -1,29 +0,0 @@
"""Beaker exception classes"""
class BeakerException(Exception):
pass
class BeakerWarning(RuntimeWarning):
"""Issued at runtime."""
class CreationAbortedError(Exception):
"""Deprecated."""
class InvalidCacheBackendError(BeakerException, ImportError):
pass
class MissingCacheParameter(BeakerException):
pass
class LockError(BeakerException):
pass
class InvalidCryptoBackendError(BeakerException):
pass

View file

@@ -1,180 +0,0 @@
from beaker._compat import pickle
import logging
from datetime import datetime
from beaker.container import OpenResourceNamespaceManager, Container
from beaker.exceptions import InvalidCacheBackendError, MissingCacheParameter
from beaker.synchronization import file_synchronizer, null_synchronizer
from beaker.util import verify_directory, SyncDict
log = logging.getLogger(__name__)
sa = None
pool = None
types = None
class DatabaseNamespaceManager(OpenResourceNamespaceManager):
metadatas = SyncDict()
tables = SyncDict()
@classmethod
def _init_dependencies(cls):
global sa, pool, types
if sa is not None:
return
try:
import sqlalchemy as sa
import sqlalchemy.pool as pool
from sqlalchemy import types
except ImportError:
raise InvalidCacheBackendError("Database cache backend requires "
"the 'sqlalchemy' library")
def __init__(self, namespace, url=None, sa_opts=None, optimistic=False,
table_name='beaker_cache', data_dir=None, lock_dir=None,
schema_name=None, **params):
"""Creates a database namespace manager
``url``
SQLAlchemy compliant db url
``sa_opts``
A dictionary of SQLAlchemy keyword options to initialize the engine
with.
``optimistic``
            Use optimistic session locking; note that this will result in an
additional select when updating a cache value to compare version
numbers.
``table_name``
The table name to use in the database for the cache.
``schema_name``
The schema name to use in the database for the cache.
"""
OpenResourceNamespaceManager.__init__(self, namespace)
if sa_opts is None:
sa_opts = {}
self.lock_dir = None
if lock_dir:
self.lock_dir = lock_dir
elif data_dir:
self.lock_dir = data_dir + "/container_db_lock"
if self.lock_dir:
verify_directory(self.lock_dir)
# Check to see if the table's been created before
url = url or sa_opts['sa.url']
table_key = url + table_name
def make_cache():
# Check to see if we have a connection pool open already
meta_key = url + table_name
def make_meta():
# SQLAlchemy pops the url, this ensures it sticks around
# later
sa_opts['sa.url'] = url
engine = sa.engine_from_config(sa_opts, 'sa.')
meta = sa.MetaData()
meta.bind = engine
return meta
meta = DatabaseNamespaceManager.metadatas.get(meta_key, make_meta)
# Create the table object and cache it now
cache = sa.Table(table_name, meta,
sa.Column('id', types.Integer, primary_key=True),
sa.Column('namespace', types.String(255), nullable=False),
sa.Column('accessed', types.DateTime, nullable=False),
sa.Column('created', types.DateTime, nullable=False),
sa.Column('data', types.PickleType, nullable=False),
sa.UniqueConstraint('namespace'),
schema=schema_name if schema_name else meta.schema
)
cache.create(checkfirst=True)
return cache
self.hash = {}
self._is_new = False
self.loaded = False
self.cache = DatabaseNamespaceManager.tables.get(table_key, make_cache)
def get_access_lock(self):
return null_synchronizer()
def get_creation_lock(self, key):
return file_synchronizer(
identifier="databasecontainer/funclock/%s/%s" % (
self.namespace, key
),
lock_dir=self.lock_dir)
def do_open(self, flags, replace):
# If we already loaded the data, don't bother loading it again
if self.loaded:
self.flags = flags
return
cache = self.cache
result_proxy = sa.select([cache.c.data],
cache.c.namespace == self.namespace
).execute()
result = result_proxy.fetchone()
result_proxy.close()
if not result:
self._is_new = True
self.hash = {}
else:
self._is_new = False
try:
self.hash = result['data']
            except (IOError, OSError, EOFError, pickle.PickleError):
                log.debug("Couldn't load pickle data, creating new storage")
self.hash = {}
self._is_new = True
self.flags = flags
self.loaded = True
def do_close(self):
if self.flags is not None and (self.flags == 'c' or self.flags == 'w'):
cache = self.cache
if self._is_new:
cache.insert().execute(namespace=self.namespace, data=self.hash,
accessed=datetime.now(),
created=datetime.now())
self._is_new = False
else:
cache.update(cache.c.namespace == self.namespace).execute(
data=self.hash, accessed=datetime.now())
self.flags = None
def do_remove(self):
cache = self.cache
cache.delete(cache.c.namespace == self.namespace).execute()
self.hash = {}
# We can retain the fact that we did a load attempt, but since the
# file is gone this will be a new namespace should it be saved.
self._is_new = True
def __getitem__(self, key):
return self.hash[key]
def __contains__(self, key):
return key in self.hash
def __setitem__(self, key, value):
self.hash[key] = value
def __delitem__(self, key):
del self.hash[key]
def keys(self):
return self.hash.keys()
class DatabaseContainer(Container):
    namespace_class = DatabaseNamespaceManager
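# Added usage sketch (not part of the original module): the SQLite URL and
# data_dir are assumptions for illustration, and the query style above needs a
# pre-2.0 SQLAlchemy. Items are buffered in self.hash and flushed to the
# database when the write lock is released.
nsm = DatabaseNamespaceManager('demo', url='sqlite:///beaker.db',
                               data_dir='/tmp/beaker')
nsm.acquire_write_lock()
try:
    nsm['answer'] = 42
finally:
    nsm.release_write_lock()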

View file

@@ -1,122 +0,0 @@
from beaker._compat import pickle
import logging
from datetime import datetime
from beaker.container import OpenResourceNamespaceManager, Container
from beaker.exceptions import InvalidCacheBackendError
from beaker.synchronization import null_synchronizer
log = logging.getLogger(__name__)
db = None
class GoogleNamespaceManager(OpenResourceNamespaceManager):
tables = {}
@classmethod
def _init_dependencies(cls):
global db
if db is not None:
return
try:
db = __import__('google.appengine.ext.db').appengine.ext.db
except ImportError:
raise InvalidCacheBackendError("Datastore cache backend requires the "
"'google.appengine.ext' library")
def __init__(self, namespace, table_name='beaker_cache', **params):
"""Creates a datastore namespace manager"""
OpenResourceNamespaceManager.__init__(self, namespace)
def make_cache():
table_dict = dict(created=db.DateTimeProperty(),
accessed=db.DateTimeProperty(),
data=db.BlobProperty())
table = type(table_name, (db.Model,), table_dict)
return table
self.table_name = table_name
self.cache = GoogleNamespaceManager.tables.setdefault(table_name, make_cache())
self.hash = {}
self._is_new = False
self.loaded = False
self.log_debug = logging.DEBUG >= log.getEffectiveLevel()
        # Google wants namespaces to start with letters, so prefix the
        # namespace with a letter
self.namespace = 'p%s' % self.namespace
def get_access_lock(self):
return null_synchronizer()
def get_creation_lock(self, key):
        # No real creation lock is provided here; the datastore backend
        # should probably have one.
return null_synchronizer()
def do_open(self, flags, replace):
# If we already loaded the data, don't bother loading it again
if self.loaded:
self.flags = flags
return
item = self.cache.get_by_key_name(self.namespace)
if not item:
self._is_new = True
self.hash = {}
else:
self._is_new = False
try:
self.hash = pickle.loads(str(item.data))
except (IOError, OSError, EOFError, pickle.PickleError):
if self.log_debug:
log.debug("Couln't load pickle data, creating new storage")
self.hash = {}
self._is_new = True
self.flags = flags
self.loaded = True
def do_close(self):
if self.flags is not None and (self.flags == 'c' or self.flags == 'w'):
if self._is_new:
item = self.cache(key_name=self.namespace)
item.data = pickle.dumps(self.hash)
item.created = datetime.now()
item.accessed = datetime.now()
item.put()
self._is_new = False
else:
item = self.cache.get_by_key_name(self.namespace)
item.data = pickle.dumps(self.hash)
item.accessed = datetime.now()
item.put()
self.flags = None
def do_remove(self):
item = self.cache.get_by_key_name(self.namespace)
item.delete()
self.hash = {}
# We can retain the fact that we did a load attempt, but since the
# file is gone this will be a new namespace should it be saved.
self._is_new = True
def __getitem__(self, key):
return self.hash[key]
def __contains__(self, key):
return key in self.hash
def __setitem__(self, key, value):
self.hash[key] = value
def __delitem__(self, key):
del self.hash[key]
def keys(self):
return self.hash.keys()
class GoogleContainer(Container):
namespace_class = GoogleNamespaceManager

View file

@@ -1,218 +0,0 @@
from .._compat import PY2
from beaker.container import NamespaceManager, Container
from beaker.crypto.util import sha1
from beaker.exceptions import InvalidCacheBackendError, MissingCacheParameter
from beaker.synchronization import file_synchronizer
from beaker.util import verify_directory, SyncDict, parse_memcached_behaviors
import warnings
MAX_KEY_LENGTH = 250
_client_libs = {}
def _load_client(name='auto'):
if name in _client_libs:
return _client_libs[name]
def _pylibmc():
global pylibmc
import pylibmc
return pylibmc
def _cmemcache():
global cmemcache
import cmemcache
warnings.warn("cmemcache is known to have serious "
"concurrency issues; consider using 'memcache' "
"or 'pylibmc'")
return cmemcache
def _memcache():
global memcache
import memcache
return memcache
def _bmemcached():
global bmemcached
import bmemcached
return bmemcached
def _auto():
for _client in (_pylibmc, _cmemcache, _memcache, _bmemcached):
try:
return _client()
except ImportError:
pass
else:
            raise InvalidCacheBackendError(
                "Memcached cache backend requires one of: 'pylibmc', "
                "'cmemcache', 'memcache' or 'bmemcached' to be installed.")
clients = {
'pylibmc': _pylibmc,
'cmemcache': _cmemcache,
'memcache': _memcache,
'bmemcached': _bmemcached,
'auto': _auto
}
_client_libs[name] = clib = clients[name]()
return clib
def _is_configured_for_pylibmc(memcache_module_config, memcache_client):
return memcache_module_config == 'pylibmc' or \
memcache_client.__name__.startswith('pylibmc')
class MemcachedNamespaceManager(NamespaceManager):
"""Provides the :class:`.NamespaceManager` API over a memcache client library."""
clients = SyncDict()
def __new__(cls, *args, **kw):
memcache_module = kw.pop('memcache_module', 'auto')
memcache_client = _load_client(memcache_module)
if _is_configured_for_pylibmc(memcache_module, memcache_client):
return object.__new__(PyLibMCNamespaceManager)
else:
return object.__new__(MemcachedNamespaceManager)
def __init__(self, namespace, url,
memcache_module='auto',
data_dir=None, lock_dir=None,
**kw):
NamespaceManager.__init__(self, namespace)
_memcache_module = _client_libs[memcache_module]
if not url:
raise MissingCacheParameter("url is required")
self.lock_dir = None
if lock_dir:
self.lock_dir = lock_dir
elif data_dir:
self.lock_dir = data_dir + "/container_mcd_lock"
if self.lock_dir:
verify_directory(self.lock_dir)
# Check for pylibmc namespace manager, in which case client will be
# instantiated by subclass __init__, to handle behavior passing to the
# pylibmc client
if not _is_configured_for_pylibmc(memcache_module, _memcache_module):
self.mc = MemcachedNamespaceManager.clients.get(
(memcache_module, url),
_memcache_module.Client,
url.split(';'))
def get_creation_lock(self, key):
return file_synchronizer(
identifier="memcachedcontainer/funclock/%s/%s" %
(self.namespace, key), lock_dir=self.lock_dir)
def _format_key(self, key):
if not isinstance(key, str):
key = key.decode('ascii')
        formatted_key = (self.namespace + '_' + key).replace(' ', '\302\267')
        if len(formatted_key) > MAX_KEY_LENGTH:
            if not PY2:
                formatted_key = formatted_key.encode('utf-8')
            formatted_key = sha1(formatted_key).hexdigest()
        return formatted_key
def __getitem__(self, key):
return self.mc.get(self._format_key(key))
def __contains__(self, key):
value = self.mc.get(self._format_key(key))
return value is not None
def has_key(self, key):
return key in self
def set_value(self, key, value, expiretime=None):
if expiretime:
self.mc.set(self._format_key(key), value, time=expiretime)
else:
self.mc.set(self._format_key(key), value)
def __setitem__(self, key, value):
self.set_value(key, value)
def __delitem__(self, key):
self.mc.delete(self._format_key(key))
def do_remove(self):
self.mc.flush_all()
def keys(self):
raise NotImplementedError(
"Memcache caching does not "
"support iteration of all cache keys")
class PyLibMCNamespaceManager(MemcachedNamespaceManager):
"""Provide thread-local support for pylibmc."""
pools = SyncDict()
def __init__(self, *arg, **kw):
super(PyLibMCNamespaceManager, self).__init__(*arg, **kw)
memcache_module = kw.get('memcache_module', 'auto')
_memcache_module = _client_libs[memcache_module]
protocol = kw.get('protocol', 'text')
username = kw.get('username', None)
password = kw.get('password', None)
url = kw.get('url')
behaviors = parse_memcached_behaviors(kw)
self.mc = MemcachedNamespaceManager.clients.get(
(memcache_module, url),
_memcache_module.Client,
servers=url.split(';'), behaviors=behaviors,
binary=(protocol == 'binary'), username=username,
password=password)
self.pool = PyLibMCNamespaceManager.pools.get(
(memcache_module, url),
pylibmc.ThreadMappedPool, self.mc)
def __getitem__(self, key):
with self.pool.reserve() as mc:
return mc.get(self._format_key(key))
def __contains__(self, key):
with self.pool.reserve() as mc:
value = mc.get(self._format_key(key))
return value is not None
def has_key(self, key):
return key in self
def set_value(self, key, value, expiretime=None):
with self.pool.reserve() as mc:
if expiretime:
mc.set(self._format_key(key), value, time=expiretime)
else:
mc.set(self._format_key(key), value)
def __setitem__(self, key, value):
self.set_value(key, value)
def __delitem__(self, key):
with self.pool.reserve() as mc:
mc.delete(self._format_key(key))
def do_remove(self):
with self.pool.reserve() as mc:
mc.flush_all()
class MemcachedContainer(Container):
"""Container class which invokes :class:`.MemcacheNamespaceManager`."""
namespace_class = MemcachedNamespaceManager
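# Added usage sketch (not part of the original module): the memcached address
# and lock_dir are assumptions; one of the client libraries above must be
# installed and the daemon reachable.
nsm = MemcachedNamespaceManager('demo', url='127.0.0.1:11211',
                                lock_dir='/tmp/beaker_locks')
nsm.set_value('greeting', 'hello', expiretime=60)
assert nsm['greeting'] == 'hello'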

View file

@@ -1,184 +0,0 @@
import datetime
import os
import threading
import time
import pickle
try:
import pymongo
import pymongo.errors
import bson
except ImportError:
pymongo = None
bson = None
from beaker.container import NamespaceManager
from beaker.synchronization import SynchronizerImpl
from beaker.util import SyncDict, machine_identifier
from beaker.crypto.util import sha1
from beaker._compat import string_type, PY2
class MongoNamespaceManager(NamespaceManager):
"""Provides the :class:`.NamespaceManager` API over MongoDB.
    The provided ``url`` can be either a mongodb connection string or
    an already existing MongoClient instance.
    The data will be stored in the ``beaker_cache`` collection of the
    *default database*, so make sure your connection string or
    MongoClient points to a default database.
"""
MAX_KEY_LENGTH = 1024
clients = SyncDict()
def __init__(self, namespace, url, **kw):
super(MongoNamespaceManager, self).__init__(namespace)
self.lock_dir = None # MongoDB uses mongo itself for locking.
if pymongo is None:
raise RuntimeError('pymongo3 is not available')
if isinstance(url, string_type):
self.client = MongoNamespaceManager.clients.get(url, pymongo.MongoClient, url)
else:
self.client = url
self.db = self.client.get_default_database()
def _format_key(self, key):
if not isinstance(key, str):
key = key.decode('ascii')
if len(key) > (self.MAX_KEY_LENGTH - len(self.namespace) - 1):
if not PY2:
key = key.encode('utf-8')
key = sha1(key).hexdigest()
return '%s:%s' % (self.namespace, key)
def get_creation_lock(self, key):
return MongoSynchronizer(self._format_key(key), self.client)
    def __getitem__(self, key):
        self._clear_expired()
        entry = self.db.beaker_cache.find_one({'_id': self._format_key(key)})
        if entry is None:
            raise KeyError(key)
        return pickle.loads(entry['value'])
    def __contains__(self, key):
        self._clear_expired()
        entry = self.db.beaker_cache.find_one({'_id': self._format_key(key)})
        return entry is not None
    def has_key(self, key):
        return key in self
    def set_value(self, key, value, expiretime=None):
        self._clear_expired()
        expiration = None
        if expiretime is not None:
            expiration = time.time() + expiretime
        value = pickle.dumps(value)
        self.db.beaker_cache.update_one({'_id': self._format_key(key)},
                                        {'$set': {'value': bson.Binary(value),
                                                  'expiration': expiration}},
                                        upsert=True)
    def __setitem__(self, key, value):
        self.set_value(key, value)
    def __delitem__(self, key):
        self._clear_expired()
        self.db.beaker_cache.delete_many({'_id': self._format_key(key)})
    def do_remove(self):
        self.db.beaker_cache.delete_many({'_id': {'$regex': '^%s' % self.namespace}})
    def keys(self):
        # Entries are stored under '_id', so recover the key part after the
        # namespace prefix; pymongo exposes find(), not find_all().
        return [e['_id'].split(':', 1)[-1] for e in self.db.beaker_cache.find(
            {'_id': {'$regex': '^%s' % self.namespace}}
        )]
    def _clear_expired(self):
        now = time.time()
        self.db.beaker_cache.delete_many({'_id': {'$regex': '^%s' % self.namespace},
                                          'expiration': {'$ne': None, '$lte': now}})
class MongoSynchronizer(SynchronizerImpl):
"""Provides a Writer/Reader lock based on MongoDB.
    The provided ``url`` can be either a mongodb connection string or
    an already existing MongoClient instance.
    The locks will be stored in the ``beaker_locks`` collection of the
    *default database*, so make sure your connection string or
    MongoClient points to a default database.
    Locks are identified by local machine, PID and thread id, so they
    are suitable for use in both local and distributed environments.
"""
    # A cache-entry creation function may take a while to run,
    # but 15 minutes is more than a reasonable upper bound.
LOCK_EXPIRATION = 900
MACHINE_ID = machine_identifier()
def __init__(self, identifier, url):
super(MongoSynchronizer, self).__init__()
self.identifier = identifier
if isinstance(url, string_type):
self.client = MongoNamespaceManager.clients.get(url, pymongo.MongoClient, url)
else:
self.client = url
self.db = self.client.get_default_database()
def _clear_expired_locks(self):
now = datetime.datetime.utcnow()
expired = now - datetime.timedelta(seconds=self.LOCK_EXPIRATION)
self.db.beaker_locks.delete_many({'_id': self.identifier, 'timestamp': {'$lte': expired}})
return now
def _get_owner_id(self):
return '%s-%s-%s' % (self.MACHINE_ID, os.getpid(), threading.current_thread().ident)
def do_release_read_lock(self):
owner_id = self._get_owner_id()
self.db.beaker_locks.update_one({'_id': self.identifier, 'readers': owner_id},
{'$pull': {'readers': owner_id}})
def do_acquire_read_lock(self, wait):
now = self._clear_expired_locks()
owner_id = self._get_owner_id()
while True:
try:
self.db.beaker_locks.update_one({'_id': self.identifier, 'owner': None},
{'$set': {'timestamp': now},
'$push': {'readers': owner_id}},
upsert=True)
return True
except pymongo.errors.DuplicateKeyError:
if not wait:
return False
time.sleep(0.2)
def do_release_write_lock(self):
self.db.beaker_locks.delete_one({'_id': self.identifier, 'owner': self._get_owner_id()})
def do_acquire_write_lock(self, wait):
now = self._clear_expired_locks()
owner_id = self._get_owner_id()
while True:
try:
self.db.beaker_locks.update_one({'_id': self.identifier, 'owner': None,
'readers': []},
{'$set': {'owner': owner_id,
'timestamp': now}},
upsert=True)
return True
except pymongo.errors.DuplicateKeyError:
if not wait:
return False
time.sleep(0.2)
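# Added usage sketch (not part of the original module): the connection string
# is an assumption and must name a default database, as the docstrings above
# require.
nsm = MongoNamespaceManager('sessions', 'mongodb://localhost:27017/demo')
nsm.set_value('user', {'name': 'alice'}, expiretime=300)
assert 'user' in nsm
assert nsm['user'] == {'name': 'alice'}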

View file

@@ -1,144 +0,0 @@
import os
import threading
import time
import pickle
try:
import redis
except ImportError:
redis = None
from beaker.container import NamespaceManager
from beaker.synchronization import SynchronizerImpl
from beaker.util import SyncDict, machine_identifier
from beaker.crypto.util import sha1
from beaker._compat import string_type, PY2
class RedisNamespaceManager(NamespaceManager):
"""Provides the :class:`.NamespaceManager` API over Redis.
    The provided ``url`` can be either a redis connection string or
an already existing StrictRedis instance.
The data will be stored into redis keys, with their name
starting with ``beaker_cache:``. So make sure you provide
a specific database number if you don't want to mix them
with your own data.
"""
MAX_KEY_LENGTH = 1024
clients = SyncDict()
def __init__(self, namespace, url, timeout=None, **kw):
super(RedisNamespaceManager, self).__init__(namespace)
self.lock_dir = None # Redis uses redis itself for locking.
self.timeout = timeout
if redis is None:
raise RuntimeError('redis is not available')
if isinstance(url, string_type):
self.client = RedisNamespaceManager.clients.get(url, redis.StrictRedis.from_url, url)
else:
self.client = url
def _format_key(self, key):
if not isinstance(key, str):
key = key.decode('ascii')
if len(key) > (self.MAX_KEY_LENGTH - len(self.namespace) - len('beaker_cache:') - 1):
if not PY2:
key = key.encode('utf-8')
key = sha1(key).hexdigest()
return 'beaker_cache:%s:%s' % (self.namespace, key)
def get_creation_lock(self, key):
return RedisSynchronizer(self._format_key(key), self.client)
def __getitem__(self, key):
entry = self.client.get(self._format_key(key))
if entry is None:
raise KeyError(key)
return pickle.loads(entry)
def __contains__(self, key):
return self.client.exists(self._format_key(key))
def has_key(self, key):
return key in self
def set_value(self, key, value, expiretime=None):
value = pickle.dumps(value)
if expiretime is None and self.timeout is not None:
expiretime = self.timeout
if expiretime is not None:
self.client.setex(self._format_key(key), int(expiretime), value)
else:
self.client.set(self._format_key(key), value)
def __setitem__(self, key, value):
self.set_value(key, value)
def __delitem__(self, key):
self.client.delete(self._format_key(key))
def do_remove(self):
for k in self.keys():
self.client.delete(k)
def keys(self):
return self.client.keys('beaker_cache:%s:*' % self.namespace)
class RedisSynchronizer(SynchronizerImpl):
"""Synchronizer based on redis.
    The provided ``url`` can be either a redis connection string or
    an already existing StrictRedis instance.
    This synchronizer only supports one reader or one writer at a time,
    not concurrent readers.
"""
    # A cache-entry creation function may take a while to run,
    # but 15 minutes is more than a reasonable upper bound.
LOCK_EXPIRATION = 900
MACHINE_ID = machine_identifier()
def __init__(self, identifier, url):
super(RedisSynchronizer, self).__init__()
self.identifier = 'beaker_lock:%s' % identifier
if isinstance(url, string_type):
self.client = RedisNamespaceManager.clients.get(url, redis.StrictRedis.from_url, url)
else:
self.client = url
def _get_owner_id(self):
return (
'%s-%s-%s' % (self.MACHINE_ID, os.getpid(), threading.current_thread().ident)
).encode('ascii')
def do_release_read_lock(self):
self.do_release_write_lock()
def do_acquire_read_lock(self, wait):
self.do_acquire_write_lock(wait)
def do_release_write_lock(self):
identifier = self.identifier
owner_id = self._get_owner_id()
def execute_release(pipe):
lock_value = pipe.get(identifier)
if lock_value == owner_id:
pipe.delete(identifier)
self.client.transaction(execute_release, identifier)
def do_acquire_write_lock(self, wait):
owner_id = self._get_owner_id()
while True:
if self.client.setnx(self.identifier, owner_id):
self.client.pexpire(self.identifier, self.LOCK_EXPIRATION * 1000)
return True
if not wait:
return False
time.sleep(0.2)
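# Added usage sketch (not part of the original module): the redis URL is an
# assumption; with timeout set, plain __setitem__ writes expire after 300
# seconds via SETEX.
nsm = RedisNamespaceManager('sessions', 'redis://localhost:6379/0', timeout=300)
nsm['token'] = 'abc123'
assert nsm['token'] == 'abc123'
del nsm['token']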

View file

@@ -1,137 +0,0 @@
from beaker._compat import pickle
import logging
from datetime import datetime
from beaker.container import OpenResourceNamespaceManager, Container
from beaker.exceptions import InvalidCacheBackendError, MissingCacheParameter
from beaker.synchronization import file_synchronizer, null_synchronizer
from beaker.util import verify_directory, SyncDict
log = logging.getLogger(__name__)
sa = None
class SqlaNamespaceManager(OpenResourceNamespaceManager):
binds = SyncDict()
tables = SyncDict()
@classmethod
def _init_dependencies(cls):
global sa
if sa is not None:
return
try:
import sqlalchemy as sa
except ImportError:
raise InvalidCacheBackendError("SQLAlchemy, which is required by "
"this backend, is not installed")
def __init__(self, namespace, bind, table, data_dir=None, lock_dir=None,
**kwargs):
"""Create a namespace manager for use with a database table via
SQLAlchemy.
``bind``
SQLAlchemy ``Engine`` or ``Connection`` object
``table``
SQLAlchemy ``Table`` object in which to store namespace data.
This should usually be something created by ``make_cache_table``.
"""
        OpenResourceNamespaceManager.__init__(self, namespace)
        self.lock_dir = None  # ensure the attribute exists when neither dir is given
        if lock_dir:
            self.lock_dir = lock_dir
        elif data_dir:
            self.lock_dir = data_dir + "/container_db_lock"
if self.lock_dir:
verify_directory(self.lock_dir)
self.bind = self.__class__.binds.get(str(bind.url), lambda: bind)
self.table = self.__class__.tables.get('%s:%s' % (bind.url, table.name),
lambda: table)
self.hash = {}
self._is_new = False
self.loaded = False
def get_access_lock(self):
return null_synchronizer()
def get_creation_lock(self, key):
return file_synchronizer(
identifier="databasecontainer/funclock/%s" % self.namespace,
lock_dir=self.lock_dir)
def do_open(self, flags, replace):
if self.loaded:
self.flags = flags
return
select = sa.select([self.table.c.data],
(self.table.c.namespace == self.namespace))
result = self.bind.execute(select).fetchone()
if not result:
self._is_new = True
self.hash = {}
else:
self._is_new = False
try:
self.hash = result['data']
            except (IOError, OSError, EOFError, pickle.PickleError):
                log.debug("Couldn't load pickle data, creating new storage")
self.hash = {}
self._is_new = True
self.flags = flags
self.loaded = True
def do_close(self):
if self.flags is not None and (self.flags == 'c' or self.flags == 'w'):
if self._is_new:
insert = self.table.insert()
self.bind.execute(insert, namespace=self.namespace, data=self.hash,
accessed=datetime.now(), created=datetime.now())
self._is_new = False
else:
update = self.table.update(self.table.c.namespace == self.namespace)
self.bind.execute(update, data=self.hash, accessed=datetime.now())
self.flags = None
def do_remove(self):
delete = self.table.delete(self.table.c.namespace == self.namespace)
self.bind.execute(delete)
self.hash = {}
self._is_new = True
def __getitem__(self, key):
return self.hash[key]
def __contains__(self, key):
return key in self.hash
def __setitem__(self, key, value):
self.hash[key] = value
def __delitem__(self, key):
del self.hash[key]
def keys(self):
return self.hash.keys()
class SqlaContainer(Container):
    namespace_class = SqlaNamespaceManager
def make_cache_table(metadata, table_name='beaker_cache', schema_name=None):
"""Return a ``Table`` object suitable for storing cached values for the
namespace manager. Do not create the table."""
return sa.Table(table_name, metadata,
sa.Column('namespace', sa.String(255), primary_key=True),
sa.Column('accessed', sa.DateTime, nullable=False),
sa.Column('created', sa.DateTime, nullable=False),
sa.Column('data', sa.PickleType, nullable=False),
schema=schema_name if schema_name else metadata.schema)
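# Added usage sketch (not part of the original module): the in-memory SQLite
# engine is an assumption, and the query style above needs a pre-2.0
# SQLAlchemy. _init_dependencies() binds the module-level ``sa`` that
# make_cache_table() relies on.
import sqlalchemy
SqlaNamespaceManager._init_dependencies()
engine = sqlalchemy.create_engine('sqlite://')
metadata = sqlalchemy.MetaData()
cache_table = make_cache_table(metadata)
cache_table.create(bind=engine)
nsm = SqlaNamespaceManager('demo', bind=engine, table=cache_table,
                           data_dir='/tmp/beaker')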

View file

@@ -1,169 +0,0 @@
import warnings
try:
from paste.registry import StackedObjectProxy
beaker_session = StackedObjectProxy(name="Beaker Session")
beaker_cache = StackedObjectProxy(name="Cache Manager")
except ImportError:
beaker_cache = None
beaker_session = None
from beaker.cache import CacheManager
from beaker.session import Session, SessionObject
from beaker.util import coerce_cache_params, coerce_session_params, \
parse_cache_config_options
class CacheMiddleware(object):
cache = beaker_cache
def __init__(self, app, config=None, environ_key='beaker.cache', **kwargs):
"""Initialize the Cache Middleware
The Cache middleware will make a CacheManager instance available
every request under the ``environ['beaker.cache']`` key by
default. The location in environ can be changed by setting
``environ_key``.
``config``
dict All settings should be prefixed by 'cache.'. This
method of passing variables is intended for Paste and other
setups that accumulate multiple component settings in a
single dictionary. If config contains *no cache. prefixed
args*, then *all* of the config options will be used to
            initialize the Cache objects.
``environ_key``
            Location where the Cache instance will be keyed in the WSGI
environ
``**kwargs``
All keyword arguments are assumed to be cache settings and
will override any settings found in ``config``
"""
self.app = app
config = config or {}
self.options = {}
# Update the options with the parsed config
self.options.update(parse_cache_config_options(config))
# Add any options from kwargs, but leave out the defaults this
# time
self.options.update(
parse_cache_config_options(kwargs, include_defaults=False))
# Assume all keys are intended for cache if none are prefixed with
# 'cache.'
if not self.options and config:
self.options = config
self.options.update(kwargs)
self.cache_manager = CacheManager(**self.options)
self.environ_key = environ_key
def __call__(self, environ, start_response):
if environ.get('paste.registry'):
if environ['paste.registry'].reglist:
environ['paste.registry'].register(self.cache,
self.cache_manager)
environ[self.environ_key] = self.cache_manager
return self.app(environ, start_response)
class SessionMiddleware(object):
session = beaker_session
def __init__(self, wrap_app, config=None, environ_key='beaker.session',
**kwargs):
"""Initialize the Session Middleware
The Session middleware will make a lazy session instance
available every request under the ``environ['beaker.session']``
key by default. The location in environ can be changed by
setting ``environ_key``.
``config``
dict All settings should be prefixed by 'session.'. This
method of passing variables is intended for Paste and other
setups that accumulate multiple component settings in a
single dictionary. If config contains *no session. prefixed
args*, then *all* of the config options will be used to
initialize the Session objects.
``environ_key``
Location where the Session instance will be keyed in the WSGI
environ
``**kwargs``
All keyword arguments are assumed to be session settings and
will override any settings found in ``config``
"""
config = config or {}
# Load up the default params
self.options = dict(invalidate_corrupt=True, type=None,
data_dir=None, key='beaker.session.id',
timeout=None, save_accessed_time=True, secret=None,
log_file=None)
# Pull out any config args meant for beaker session. if there are any
for dct in [config, kwargs]:
for key, val in dct.items():
if key.startswith('beaker.session.'):
self.options[key[15:]] = val
if key.startswith('session.'):
self.options[key[8:]] = val
if key.startswith('session_'):
warnings.warn('Session options should start with session. '
'instead of session_.', DeprecationWarning, 2)
self.options[key[8:]] = val
# Coerce and validate session params
coerce_session_params(self.options)
# Assume all keys are intended for session if none are prefixed with
# 'session.'
if not self.options and config:
self.options = config
self.options.update(kwargs)
self.wrap_app = self.app = wrap_app
self.environ_key = environ_key
def __call__(self, environ, start_response):
session = SessionObject(environ, **self.options)
if environ.get('paste.registry'):
if environ['paste.registry'].reglist:
environ['paste.registry'].register(self.session, session)
environ[self.environ_key] = session
environ['beaker.get_session'] = self._get_session
if 'paste.testing_variables' in environ and 'webtest_varname' in self.options:
environ['paste.testing_variables'][self.options['webtest_varname']] = session
def session_start_response(status, headers, exc_info=None):
if session.accessed():
session.persist()
if session.__dict__['_headers']['set_cookie']:
cookie = session.__dict__['_headers']['cookie_out']
if cookie:
headers.append(('Set-cookie', cookie))
return start_response(status, headers, exc_info)
return self.wrap_app(environ, session_start_response)
def _get_session(self):
return Session({}, use_cookies=False, **self.options)
def session_filter_factory(global_conf, **kwargs):
def filter(app):
return SessionMiddleware(app, global_conf, **kwargs)
return filter
def session_filter_app_factory(app, global_conf, **kwargs):
return SessionMiddleware(app, global_conf, **kwargs)
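For symmetry, a sketch of the session side (not part of the diff; the data_dir path is illustrative and must be writable):

from beaker.middleware import SessionMiddleware

def app(environ, start_response):
    session = environ['beaker.session']          # a lazy SessionObject
    session['visits'] = session.get('visits', 0) + 1
    session.save()
    start_response('200 OK', [('Content-Type', 'text/plain')])
    return [('visits: %d' % session['visits']).encode('ascii')]

app = SessionMiddleware(app, {'session.type': 'file',
                              'session.data_dir': '/tmp/beaker-sessions'})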


@ -1,845 +0,0 @@
from ._compat import PY2, pickle, http_cookies, unicode_text, b64encode, b64decode, string_type
import os
import time
from datetime import datetime, timedelta
from beaker.crypto import hmac as HMAC, hmac_sha1 as SHA1, sha1, get_nonce_size, DEFAULT_NONCE_BITS, get_crypto_module
from beaker import crypto, util
from beaker.cache import clsmap
from beaker.exceptions import BeakerException, InvalidCryptoBackendError
from beaker.cookie import SimpleCookie
__all__ = ['SignedCookie', 'Session', 'InvalidSignature']
class _InvalidSignatureType(object):
"""Returned from SignedCookie when the value's signature was invalid."""
def __nonzero__(self):
return False
def __bool__(self):
return False
InvalidSignature = _InvalidSignatureType()
try:
import uuid
def _session_id():
return uuid.uuid4().hex
except ImportError:
import random
if hasattr(os, 'getpid'):
getpid = os.getpid
else:
def getpid():
return ''
def _session_id():
id_str = "%f%s%f%s" % (
time.time(),
id({}),
random.random(),
getpid()
)
# NB: nothing against second parameter to b64encode, but it seems
# to be slower than simple chained replacement
if not PY2:
raw_id = b64encode(sha1(id_str.encode('ascii')).digest())
return str(raw_id.replace(b'+', b'-').replace(b'/', b'_').rstrip(b'='))
else:
raw_id = b64encode(sha1(id_str).digest())
return raw_id.replace('+', '-').replace('/', '_').rstrip('=')
class SignedCookie(SimpleCookie):
"""Extends python cookie to give digital signature support"""
def __init__(self, secret, input=None):
self.secret = secret.encode('UTF-8')
http_cookies.BaseCookie.__init__(self, input)
def value_decode(self, val):
val = val.strip('"')
if not val:
return None, val
sig = HMAC.new(self.secret, val[40:].encode('utf-8'), SHA1).hexdigest()
# Avoid timing attacks
invalid_bits = 0
input_sig = val[:40]
if len(sig) != len(input_sig):
return InvalidSignature, val
for a, b in zip(sig, input_sig):
invalid_bits += a != b
if invalid_bits:
return InvalidSignature, val
else:
return val[40:], val
def value_encode(self, val):
sig = HMAC.new(self.secret, val.encode('utf-8'), SHA1).hexdigest()
return str(val), ("%s%s" % (sig, val))
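The scheme above prepends a 40-hex-character SHA1 HMAC to the cookie value and verifies it with a constant-time comparison. A standalone sketch of the same scheme using only the standard library (secret and value are illustrative):

import hashlib
import hmac

secret = b'super-secret'
value = 'session-id-1234'

# Sign: prepend the hex HMAC, as value_encode() does.
sig = hmac.new(secret, value.encode('utf-8'), hashlib.sha1).hexdigest()
cookie_value = sig + value

# Verify: recompute over the payload (everything after the first 40
# characters) and compare in constant time, as value_decode() does.
expected = hmac.new(secret, cookie_value[40:].encode('utf-8'),
                    hashlib.sha1).hexdigest()
assert hmac.compare_digest(expected, cookie_value[:40])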
class Session(dict):
"""Session object that uses container package for storage.
:param invalidate_corrupt: How to handle corrupt data when loading. When
set to True, then corrupt data will be silently
invalidated and a new session created,
otherwise invalid data will cause an exception.
:type invalidate_corrupt: bool
:param use_cookies: Whether or not cookies should be created. When set to
False, it is assumed the user will handle storing the
session on their own.
:type use_cookies: bool
:param type: What data backend type should be used to store the underlying
session data
:param key: The name the cookie should be set to.
:param timeout: How long session data is considered valid. This is used
regardless of the cookie being present or not to determine
whether session data is still valid. Can be set to None to
disable session time out.
:type timeout: int or None
:param save_accessed_time: Whether beaker should save the session's access
time (True) or only modification time (False).
Defaults to True.
:param cookie_expires: Expiration date for cookie
:param cookie_domain: Domain to use for the cookie.
:param cookie_path: Path to use for the cookie.
:param data_serializer: If ``"json"`` or ``"pickle"`` should be used
to serialize data. Can also be an object with
``loads`` and ``dumps`` methods. By default
``"pickle"`` is used.
:param secure: Whether or not the cookie should only be sent over SSL.
:param httponly: Whether or not the cookie should only be accessible by
the browser not by JavaScript.
:param encrypt_key: The key to use for the local session encryption, if not
provided the session will not be encrypted.
:param validate_key: The key used to sign the local encrypted session
:param encrypt_nonce_bits: Number of bits used to generate nonce for encryption key salt.
For security reasons this is 128 bits by default. If you want
to keep backward compatibility with sessions generated before 1.8.0
set this to 48.
:param crypto_type: encryption module to use
:param samesite: SameSite value for the cookie -- should be either 'Lax',
'Strict', or None.
"""
def __init__(self, request, id=None, invalidate_corrupt=False,
use_cookies=True, type=None, data_dir=None,
key='beaker.session.id', timeout=None, save_accessed_time=True,
cookie_expires=True, cookie_domain=None, cookie_path='/',
data_serializer='pickle', secret=None,
secure=False, namespace_class=None, httponly=False,
encrypt_key=None, validate_key=None, encrypt_nonce_bits=DEFAULT_NONCE_BITS,
crypto_type='default', samesite='Lax',
**namespace_args):
if not type:
if data_dir:
self.type = 'file'
else:
self.type = 'memory'
else:
self.type = type
self.namespace_class = namespace_class or clsmap[self.type]
self.namespace_args = namespace_args
self.request = request
self.data_dir = data_dir
self.key = key
if timeout and not save_accessed_time:
raise BeakerException("timeout requires save_accessed_time")
self.timeout = timeout
# If a timeout was provided, forward it to the backend too, so the backend
# can automatically expire entries if it's supported.
if self.timeout is not None:
# The backend expiration should always be a bit longer than the
# session expiration itself to prevent the case where the backend data expires while
# the session is being read (PR#153). 2 Minutes seems a reasonable time.
self.namespace_args['timeout'] = self.timeout + 60 * 2
self.save_atime = save_accessed_time
self.use_cookies = use_cookies
self.cookie_expires = cookie_expires
self._set_serializer(data_serializer)
# Default cookie domain/path
self._domain = cookie_domain
self._path = cookie_path
self.was_invalidated = False
self.secret = secret
self.secure = secure
self.httponly = httponly
self.samesite = samesite
self.encrypt_key = encrypt_key
self.validate_key = validate_key
self.encrypt_nonce_size = get_nonce_size(encrypt_nonce_bits)
self.crypto_module = get_crypto_module(crypto_type)
self.id = id
self.accessed_dict = {}
self.invalidate_corrupt = invalidate_corrupt
if self.use_cookies:
cookieheader = request.get('cookie', '')
if secret:
try:
self.cookie = SignedCookie(
secret,
input=cookieheader,
)
except http_cookies.CookieError:
self.cookie = SignedCookie(
secret,
input=None,
)
else:
self.cookie = SimpleCookie(input=cookieheader)
if not self.id and self.key in self.cookie:
cookie_data = self.cookie[self.key].value
# Should we check invalidate_corrupt here?
if cookie_data is InvalidSignature:
cookie_data = None
self.id = cookie_data
self.is_new = self.id is None
if self.is_new:
self._create_id()
self['_accessed_time'] = self['_creation_time'] = time.time()
else:
try:
self.load()
except Exception as e:
if self.invalidate_corrupt:
util.warn(
"Invalidating corrupt session %s; "
"error was: %s. Set invalidate_corrupt=False "
"to propagate this exception." % (self.id, e))
self.invalidate()
else:
raise
def _set_serializer(self, data_serializer):
self.data_serializer = data_serializer
if self.data_serializer == 'json':
self.serializer = util.JsonSerializer()
elif self.data_serializer == 'pickle':
self.serializer = util.PickleSerializer()
elif isinstance(self.data_serializer, string_type):
raise BeakerException('Invalid value for data_serializer: %s' % data_serializer)
else:
self.serializer = data_serializer
def has_key(self, name):
return name in self
def _set_cookie_values(self, expires=None):
self.cookie[self.key] = self.id
if self._domain:
self.cookie[self.key]['domain'] = self._domain
if self.secure:
self.cookie[self.key]['secure'] = True
if self.samesite:
self.cookie[self.key]['samesite'] = self.samesite
self._set_cookie_http_only()
self.cookie[self.key]['path'] = self._path
self._set_cookie_expires(expires)
def _set_cookie_expires(self, expires):
if expires is None:
expires = self.cookie_expires
if expires is False:
expires_date = datetime.fromtimestamp(0x7FFFFFFF)
elif isinstance(expires, timedelta):
expires_date = datetime.utcnow() + expires
elif isinstance(expires, datetime):
expires_date = expires
elif expires is not True:
raise ValueError("Invalid argument for cookie_expires: %s"
% repr(self.cookie_expires))
self.cookie_expires = expires
if not self.cookie or self.key not in self.cookie:
self.cookie[self.key] = self.id
if expires is True:
self.cookie[self.key]['expires'] = ''
return True
self.cookie[self.key]['expires'] = \
expires_date.strftime("%a, %d-%b-%Y %H:%M:%S GMT")
return expires_date
def _update_cookie_out(self, set_cookie=True):
self._set_cookie_values()
self.request['cookie_out'] = self.cookie[self.key].output(header='')
self.request['set_cookie'] = set_cookie
def _set_cookie_http_only(self):
try:
if self.httponly:
self.cookie[self.key]['httponly'] = True
except http_cookies.CookieError as e:
if 'Invalid Attribute httponly' not in str(e):
raise
util.warn('Python 2.6+ is required to use httponly')
def _create_id(self, set_new=True):
self.id = _session_id()
if set_new:
self.is_new = True
self.last_accessed = None
if self.use_cookies:
sc = set_new is False
self._update_cookie_out(set_cookie=sc)
@property
def created(self):
return self['_creation_time']
def _set_domain(self, domain):
self['_domain'] = self._domain = domain
self._update_cookie_out()
def _get_domain(self):
return self._domain
domain = property(_get_domain, _set_domain)
def _set_path(self, path):
self['_path'] = self._path = path
self._update_cookie_out()
def _get_path(self):
return self._path
path = property(_get_path, _set_path)
def _encrypt_data(self, session_data=None):
"""Serialize, encipher, and base64 the session dict"""
session_data = session_data or self.copy()
if self.encrypt_key:
nonce_len, nonce_b64len = self.encrypt_nonce_size
nonce = b64encode(os.urandom(nonce_len))[:nonce_b64len]
encrypt_key = crypto.generateCryptoKeys(self.encrypt_key,
self.validate_key + nonce,
1,
self.crypto_module.getKeyLength())
data = self.serializer.dumps(session_data)
return nonce + b64encode(self.crypto_module.aesEncrypt(data, encrypt_key))
else:
data = self.serializer.dumps(session_data)
return b64encode(data)
def _decrypt_data(self, session_data):
"""Base64, decipher, then un-serialize the data for the session
dict"""
if self.encrypt_key:
__, nonce_b64len = self.encrypt_nonce_size
nonce = session_data[:nonce_b64len]
encrypt_key = crypto.generateCryptoKeys(self.encrypt_key,
self.validate_key + nonce,
1,
self.crypto_module.getKeyLength())
payload = b64decode(session_data[nonce_b64len:])
data = self.crypto_module.aesDecrypt(payload, encrypt_key)
else:
data = b64decode(session_data)
return self.serializer.loads(data)
def _delete_cookie(self):
self.request['set_cookie'] = True
expires = datetime.utcnow() - timedelta(365)
self._set_cookie_values(expires)
self._update_cookie_out()
def delete(self):
"""Deletes the session from the persistent storage, and sends
an expired cookie out"""
if self.use_cookies:
self._delete_cookie()
self.clear()
def invalidate(self):
"""Invalidates this session, creates a new session id, returns
to the is_new state"""
self.clear()
self.was_invalidated = True
self._create_id()
self.load()
def load(self):
"Loads the data from this session from persistent storage"
self.namespace = self.namespace_class(self.id,
data_dir=self.data_dir,
digest_filenames=False,
**self.namespace_args)
now = time.time()
if self.use_cookies:
self.request['set_cookie'] = True
self.namespace.acquire_read_lock()
timed_out = False
try:
self.clear()
try:
session_data = self.namespace['session']
if (session_data is not None and self.encrypt_key):
session_data = self._decrypt_data(session_data)
# Memcached always returns a key; it's None when it's not
# present
if session_data is None:
session_data = {
'_creation_time': now,
'_accessed_time': now
}
self.is_new = True
except (KeyError, TypeError):
session_data = {
'_creation_time': now,
'_accessed_time': now
}
self.is_new = True
if session_data is None or len(session_data) == 0:
session_data = {
'_creation_time': now,
'_accessed_time': now
}
self.is_new = True
if self.timeout is not None and \
now - session_data['_accessed_time'] > self.timeout:
timed_out = True
else:
# Properly set the last_accessed time, which is different
# from the *current* _accessed_time
if self.is_new or '_accessed_time' not in session_data:
self.last_accessed = None
else:
self.last_accessed = session_data['_accessed_time']
# Update the current _accessed_time
session_data['_accessed_time'] = now
# Set the path if applicable
if '_path' in session_data:
self._path = session_data['_path']
self.update(session_data)
self.accessed_dict = session_data.copy()
finally:
self.namespace.release_read_lock()
if timed_out:
self.invalidate()
def save(self, accessed_only=False):
"""Saves the data for this session to persistent storage
If accessed_only is True, then only the original data loaded
at the beginning of the request will be saved, with the updated
last accessed time.
"""
# Look to see if it's a new session that was only accessed;
# don't save it in that case
if accessed_only and (self.is_new or not self.save_atime):
return None
# this session might not have a namespace yet or the session id
# might have been regenerated
if not hasattr(self, 'namespace') or self.namespace.namespace != self.id:
self.namespace = self.namespace_class(
self.id,
data_dir=self.data_dir,
digest_filenames=False,
**self.namespace_args)
self.namespace.acquire_write_lock(replace=True)
try:
if accessed_only:
data = dict(self.accessed_dict.items())
else:
data = dict(self.items())
if self.encrypt_key:
data = self._encrypt_data(data)
# Save the data
if not data and 'session' in self.namespace:
del self.namespace['session']
else:
self.namespace['session'] = data
finally:
self.namespace.release_write_lock()
if self.use_cookies and self.is_new:
self.request['set_cookie'] = True
def revert(self):
"""Revert the session to its original state from its first
access in the request"""
self.clear()
self.update(self.accessed_dict)
def regenerate_id(self):
"""
Creates a new session id while retaining all session data.
It's a good security practice to regenerate the id after a client
elevates privileges.
"""
self._create_id(set_new=False)
# TODO: I think both these methods should be removed. They're from
# the original mod_python code i was ripping off but they really
# have no use here.
def lock(self):
"""Locks this session against other processes/threads. This is
automatic when load/save is called.
***use with caution*** and always with a corresponding 'unlock'
inside a "finally:" block, as a stray lock typically cannot be
unlocked without shutting down the whole application.
"""
self.namespace.acquire_write_lock()
def unlock(self):
"""Unlocks this session against other processes/threads. This
is automatic when load/save is called.
***use with caution*** and always within a "finally:" block, as
a stray lock typically cannot be unlocked without shutting down
the whole application.
"""
self.namespace.release_write_lock()
class CookieSession(Session):
"""Pure cookie-based session
Options recognized when using cookie-based sessions are slightly
more restricted than general sessions.
:param key: The name the cookie should be set to.
:param timeout: How long session data is considered valid. This is used
regardless of the cookie being present or not to determine
whether session data is still valid.
:type timeout: int
:param save_accessed_time: Whether beaker should save the session's access
time (True) or only modification time (False).
Defaults to True.
:param cookie_expires: Expiration date for cookie
:param cookie_domain: Domain to use for the cookie.
:param cookie_path: Path to use for the cookie.
:param data_serializer: If ``"json"`` or ``"pickle"`` should be used
to serialize data. Can also be an object with
``loads`` and ``dumps`` methods. By default
``"pickle"`` is used.
:param secure: Whether or not the cookie should only be sent over SSL.
:param httponly: Whether or not the cookie should only be accessible by
the browser not by JavaScript.
:param encrypt_key: The key to use for the local session encryption, if not
provided the session will not be encrypted.
:param validate_key: The key used to sign the local encrypted session
:param invalidate_corrupt: How to handle corrupt data when loading. When
set to True, then corrupt data will be silently
invalidated and a new session created,
otherwise invalid data will cause an exception.
:type invalidate_corrupt: bool
:param crypto_type: The crypto module to use.
:param samesite: SameSite value for the cookie -- should be either 'Lax',
'Strict', or None.
"""
def __init__(self, request, key='beaker.session.id', timeout=None,
save_accessed_time=True, cookie_expires=True, cookie_domain=None,
cookie_path='/', encrypt_key=None, validate_key=None, secure=False,
httponly=False, data_serializer='pickle',
encrypt_nonce_bits=DEFAULT_NONCE_BITS, invalidate_corrupt=False,
crypto_type='default', samesite='Lax',
**kwargs):
self.crypto_module = get_crypto_module(crypto_type)
if encrypt_key and not self.crypto_module.has_aes:
raise InvalidCryptoBackendError("No AES library is installed, can't generate "
"encrypted cookie-only Session.")
self.request = request
self.key = key
self.timeout = timeout
self.save_atime = save_accessed_time
self.cookie_expires = cookie_expires
self.encrypt_key = encrypt_key
self.validate_key = validate_key
self.encrypt_nonce_size = get_nonce_size(encrypt_nonce_bits)
self.request['set_cookie'] = False
self.secure = secure
self.httponly = httponly
self.samesite = samesite
self._domain = cookie_domain
self._path = cookie_path
self.invalidate_corrupt = invalidate_corrupt
self._set_serializer(data_serializer)
try:
cookieheader = request['cookie']
except KeyError:
cookieheader = ''
if validate_key is None:
raise BeakerException("No validate_key specified for Cookie only "
"Session.")
if timeout and not save_accessed_time:
raise BeakerException("timeout requires save_accessed_time")
try:
self.cookie = SignedCookie(
validate_key,
input=cookieheader,
)
except http_cookies.CookieError:
self.cookie = SignedCookie(
validate_key,
input=None,
)
self['_id'] = _session_id()
self.is_new = True
# If we have a cookie, load it
if self.key in self.cookie and self.cookie[self.key].value is not None:
self.is_new = False
try:
cookie_data = self.cookie[self.key].value
if cookie_data is InvalidSignature:
raise BeakerException("Invalid signature")
self.update(self._decrypt_data(cookie_data))
self._path = self.get('_path', '/')
except Exception as e:
if self.invalidate_corrupt:
util.warn(
"Invalidating corrupt session %s; "
"error was: %s. Set invalidate_corrupt=False "
"to propagate this exception." % (self.id, e))
self.invalidate()
else:
raise
if self.timeout is not None:
now = time.time()
last_accessed_time = self.get('_accessed_time', now)
if now - last_accessed_time > self.timeout:
self.clear()
self.accessed_dict = self.copy()
self._create_cookie()
def created(self):
return self['_creation_time']
created = property(created)
def id(self):
return self['_id']
id = property(id)
def _set_domain(self, domain):
self['_domain'] = domain
self._domain = domain
def _get_domain(self):
return self._domain
domain = property(_get_domain, _set_domain)
def _set_path(self, path):
self['_path'] = self._path = path
def _get_path(self):
return self._path
path = property(_get_path, _set_path)
def save(self, accessed_only=False):
"""Saves the data for this session to persistent storage"""
if accessed_only and (self.is_new or not self.save_atime):
return
if accessed_only:
self.clear()
self.update(self.accessed_dict)
self._create_cookie()
def expire(self):
"""Delete the 'expires' attribute on this Session, if any."""
self.pop('_expires', None)
def _create_cookie(self):
if '_creation_time' not in self:
self['_creation_time'] = time.time()
if '_id' not in self:
self['_id'] = _session_id()
self['_accessed_time'] = time.time()
val = self._encrypt_data()
if len(val) > 4064:
raise BeakerException("Cookie value is too long to store")
self.cookie[self.key] = val
if '_expires' in self:
expires = self['_expires']
else:
expires = None
expires = self._set_cookie_expires(expires)
if expires is not None:
self['_expires'] = expires
if '_domain' in self:
self.cookie[self.key]['domain'] = self['_domain']
elif self._domain:
self.cookie[self.key]['domain'] = self._domain
if self.secure:
self.cookie[self.key]['secure'] = True
self._set_cookie_http_only()
self.cookie[self.key]['path'] = self.get('_path', '/')
self.request['cookie_out'] = self.cookie[self.key].output(header='')
self.request['set_cookie'] = True
def delete(self):
"""Delete the cookie, and clear the session"""
# Send a delete cookie request
self._delete_cookie()
self.clear()
def invalidate(self):
"""Clear the contents and start a new session"""
self.clear()
self['_id'] = _session_id()
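A minimal sketch of driving CookieSession directly (the request dict and key values are illustrative; beaker normally constructs the request dict inside SessionObject):

from beaker.session import CookieSession

request = {'cookie': ''}   # the minimal request dict the class expects
session = CookieSession(request, validate_key='my-validate-key')
session['user'] = 'alice'
session.save()
# The signed (and, with encrypt_key set, AES-encrypted) payload comes
# back through the request dict, ready for a Set-Cookie header.
print(request['cookie_out'])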
class SessionObject(object):
"""Session proxy/lazy creator
This object proxies access to the actual session object, so that in
the case that the session hasn't been used before, it will be
set up. This avoids creating and loading the session from persistent
storage unless it's actually used during the request.
"""
def __init__(self, environ, **params):
self.__dict__['_params'] = params
self.__dict__['_environ'] = environ
self.__dict__['_sess'] = None
self.__dict__['_headers'] = {}
def _session(self):
"""Lazy initial creation of session object"""
if self.__dict__['_sess'] is None:
params = self.__dict__['_params']
environ = self.__dict__['_environ']
self.__dict__['_headers'] = req = {'cookie_out': None}
req['cookie'] = environ.get('HTTP_COOKIE')
session_cls = params.get('session_class', None)
if session_cls is None:
if params.get('type') == 'cookie':
session_cls = CookieSession
else:
session_cls = Session
else:
assert issubclass(session_cls, Session),\
"Not a Session: " + session_cls
self.__dict__['_sess'] = session_cls(req, **params)
return self.__dict__['_sess']
def __getattr__(self, attr):
return getattr(self._session(), attr)
def __setattr__(self, attr, value):
setattr(self._session(), attr, value)
def __delattr__(self, name):
self._session().__delattr__(name)
def __getitem__(self, key):
return self._session()[key]
def __setitem__(self, key, value):
self._session()[key] = value
def __delitem__(self, key):
self._session().__delitem__(key)
def __repr__(self):
return self._session().__repr__()
def __iter__(self):
"""Only works for proxying to a dict"""
return iter(self._session().keys())
def __contains__(self, key):
return key in self._session()
def has_key(self, key):
return key in self._session()
def get_by_id(self, id):
"""Loads a session given a session ID"""
params = self.__dict__['_params']
session = Session({}, use_cookies=False, id=id, **params)
if session.is_new:
return None
return session
def save(self):
self.__dict__['_dirty'] = True
def delete(self):
self.__dict__['_dirty'] = True
self._session().delete()
def persist(self):
"""Persist the session to the storage
Always saves the whole session if save() or delete() have been called.
If they haven't:
- If autosave is set to true, saves the entire session regardless.
- If save_accessed_time is set to true or unset, only saves the updated
access time.
- If save_accessed_time is set to false, doesn't save anything.
"""
if self.__dict__['_params'].get('auto'):
self._session().save()
elif self.__dict__['_params'].get('save_accessed_time', True):
if self.dirty():
self._session().save()
else:
self._session().save(accessed_only=True)
else: # save_accessed_time is false
if self.dirty():
self._session().save()
def dirty(self):
"""Returns True if save() or delete() have been called"""
return self.__dict__.get('_dirty', False)
def accessed(self):
"""Returns whether or not the session has been accessed"""
return self.__dict__['_sess'] is not None
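A sketch of the lazy behaviour described in the class docstring, assuming an in-memory backend:

from beaker.session import SessionObject

environ = {'HTTP_COOKIE': ''}
session = SessionObject(environ, type='memory')
assert not session.accessed()     # no real Session created yet
session['user'] = 'alice'         # first access builds and loads it
assert session.accessed()
session.save()                    # marks the proxy dirty
session.persist()                 # now the full session is written out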


@ -1,392 +0,0 @@
"""Synchronization functions.
File- and mutex-based mutual exclusion synchronizers are provided,
as well as a name-based mutex which locks within an application
based on a string name.
"""
import errno
import os
import sys
import tempfile
try:
import threading as _threading
except ImportError:
import dummy_threading as _threading
# check for fcntl module
try:
sys.getwindowsversion()
has_flock = False
except:
try:
import fcntl
has_flock = True
except ImportError:
has_flock = False
from beaker import util
from beaker.exceptions import LockError
__all__ = ["file_synchronizer", "mutex_synchronizer", "null_synchronizer",
"NameLock", "_threading"]
class NameLock(object):
"""a proxy for an RLock object that is stored in a name based
registry.
Multiple threads can get a reference to the same RLock based on the
name alone, and synchronize operations related to that name.
"""
locks = util.WeakValuedRegistry()
class NLContainer(object):
def __init__(self, reentrant):
if reentrant:
self.lock = _threading.RLock()
else:
self.lock = _threading.Lock()
def __call__(self):
return self.lock
def __init__(self, identifier=None, reentrant=False):
if identifier is None:
self._lock = NameLock.NLContainer(reentrant)
else:
self._lock = NameLock.locks.get(identifier, NameLock.NLContainer,
reentrant)
def acquire(self, wait=True):
return self._lock().acquire(wait)
def release(self):
self._lock().release()
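A sketch of the name-based sharing (the identifier and reentrant flag are illustrative):

from beaker.synchronization import NameLock

first = NameLock('resource-a', reentrant=True)
second = NameLock('resource-a', reentrant=True)
# Both instances resolve to the same underlying RLock through the
# weak-valued registry, so reacquiring from the same thread works.
first.acquire()
second.acquire()
second.release()
first.release()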
_synchronizers = util.WeakValuedRegistry()
def _synchronizer(identifier, cls, **kwargs):
return _synchronizers.sync_get((identifier, cls), cls, identifier, **kwargs)
def file_synchronizer(identifier, **kwargs):
if not has_flock or 'lock_dir' not in kwargs:
return mutex_synchronizer(identifier)
else:
return _synchronizer(identifier, FileSynchronizer, **kwargs)
def mutex_synchronizer(identifier, **kwargs):
return _synchronizer(identifier, ConditionSynchronizer, **kwargs)
class null_synchronizer(object):
"""A 'null' synchronizer, which provides the :class:`.SynchronizerImpl` interface
without any locking.
"""
def acquire_write_lock(self, wait=True):
return True
def acquire_read_lock(self):
pass
def release_write_lock(self):
pass
def release_read_lock(self):
pass
acquire = acquire_write_lock
release = release_write_lock
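A sketch of the file_synchronizer factory in use (lock_dir is illustrative); note from the factory above that without flock() support or a lock_dir it silently degrades to the in-process ConditionSynchronizer:

from beaker.synchronization import file_synchronizer

lock = file_synchronizer('cache-namespace', lock_dir='/tmp/beaker-locks')
lock.acquire_write_lock()
try:
    pass  # exclusive access to the guarded resource
finally:
    lock.release_write_lock()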
class SynchronizerImpl(object):
"""Base class for a synchronization object that allows
multiple readers, single writers.
"""
def __init__(self):
self._state = util.ThreadLocal()
class SyncState(object):
__slots__ = 'reentrantcount', 'writing', 'reading'
def __init__(self):
self.reentrantcount = 0
self.writing = False
self.reading = False
def state(self):
if not self._state.has():
state = SynchronizerImpl.SyncState()
self._state.put(state)
return state
else:
return self._state.get()
state = property(state)
def release_read_lock(self):
state = self.state
if state.writing:
raise LockError("lock is in writing state")
if not state.reading:
raise LockError("lock is not in reading state")
if state.reentrantcount == 1:
self.do_release_read_lock()
state.reading = False
state.reentrantcount -= 1
def acquire_read_lock(self, wait=True):
state = self.state
if state.writing:
raise LockError("lock is in writing state")
if state.reentrantcount == 0:
x = self.do_acquire_read_lock(wait)
if (wait or x):
state.reentrantcount += 1
state.reading = True
return x
elif state.reading:
state.reentrantcount += 1
return True
def release_write_lock(self):
state = self.state
if state.reading:
raise LockError("lock is in reading state")
if not state.writing:
raise LockError("lock is not in writing state")
if state.reentrantcount == 1:
self.do_release_write_lock()
state.writing = False
state.reentrantcount -= 1
release = release_write_lock
def acquire_write_lock(self, wait=True):
state = self.state
if state.reading:
raise LockError("lock is in reading state")
if state.reentrantcount == 0:
x = self.do_acquire_write_lock(wait)
if (wait or x):
state.reentrantcount += 1
state.writing = True
return x
elif state.writing:
state.reentrantcount += 1
return True
acquire = acquire_write_lock
def do_release_read_lock(self):
raise NotImplementedError()
def do_acquire_read_lock(self, wait):
raise NotImplementedError()
def do_release_write_lock(self):
raise NotImplementedError()
def do_acquire_write_lock(self, wait):
raise NotImplementedError()
class FileSynchronizer(SynchronizerImpl):
"""A synchronizer which locks using flock().
"""
def __init__(self, identifier, lock_dir):
super(FileSynchronizer, self).__init__()
self._filedescriptor = util.ThreadLocal()
if lock_dir is None:
lock_dir = tempfile.gettempdir()
self.filename = util.encoded_path(
lock_dir,
[identifier],
extension='.lock'
)
self.lock_dir = os.path.dirname(self.filename)
def _filedesc(self):
return self._filedescriptor.get()
_filedesc = property(_filedesc)
def _ensuredir(self):
if not os.path.exists(self.lock_dir):
util.verify_directory(self.lock_dir)
def _open(self, mode):
filedescriptor = self._filedesc
if filedescriptor is None:
self._ensuredir()
filedescriptor = os.open(self.filename, mode)
self._filedescriptor.put(filedescriptor)
return filedescriptor
def do_acquire_read_lock(self, wait):
filedescriptor = self._open(os.O_CREAT | os.O_RDONLY)
if not wait:
try:
fcntl.flock(filedescriptor, fcntl.LOCK_SH | fcntl.LOCK_NB)
return True
except IOError:
os.close(filedescriptor)
self._filedescriptor.remove()
return False
else:
fcntl.flock(filedescriptor, fcntl.LOCK_SH)
return True
def do_acquire_write_lock(self, wait):
filedescriptor = self._open(os.O_CREAT | os.O_WRONLY)
if not wait:
try:
fcntl.flock(filedescriptor, fcntl.LOCK_EX | fcntl.LOCK_NB)
return True
except IOError:
os.close(filedescriptor)
self._filedescriptor.remove()
return False
else:
fcntl.flock(filedescriptor, fcntl.LOCK_EX)
return True
def do_release_read_lock(self):
self._release_all_locks()
def do_release_write_lock(self):
self._release_all_locks()
def _release_all_locks(self):
filedescriptor = self._filedesc
if filedescriptor is not None:
fcntl.flock(filedescriptor, fcntl.LOCK_UN)
os.close(filedescriptor)
self._filedescriptor.remove()
class ConditionSynchronizer(SynchronizerImpl):
"""a synchronizer using a Condition."""
def __init__(self, identifier):
super(ConditionSynchronizer, self).__init__()
# counts how many asynchronous methods are executing
self.asynch = 0
# pointer to thread that is the current sync operation
self.current_sync_operation = None
# condition object to lock on
self.condition = _threading.Condition(_threading.Lock())
def do_acquire_read_lock(self, wait=True):
self.condition.acquire()
try:
# see if a synchronous operation is waiting to start
# or is already running, in which case we wait (or just
# give up and return)
if wait:
while self.current_sync_operation is not None:
self.condition.wait()
else:
if self.current_sync_operation is not None:
return False
self.asynch += 1
finally:
self.condition.release()
if not wait:
return True
def do_release_read_lock(self):
self.condition.acquire()
try:
self.asynch -= 1
# check if we are the last asynchronous reader thread
# out the door.
if self.asynch == 0:
# yes. so if a sync operation is waiting, notifyAll to wake
# it up
if self.current_sync_operation is not None:
self.condition.notifyAll()
elif self.asynch < 0:
raise LockError("Synchronizer error - too many "
"release_read_locks called")
finally:
self.condition.release()
def do_acquire_write_lock(self, wait=True):
self.condition.acquire()
try:
# here, we are not a synchronous reader, and after returning,
# assuming waiting or immediate availability, we will be.
if wait:
# if another sync is working, wait
while self.current_sync_operation is not None:
self.condition.wait()
else:
# if another sync is working,
# we don't want to wait, so forget it
if self.current_sync_operation is not None:
return False
# establish ourselves as the current sync
# this indicates to other read/write operations
# that they should wait until this is None again
self.current_sync_operation = _threading.currentThread()
# now wait again for asyncs to finish
if self.asynch > 0:
if wait:
# wait
self.condition.wait()
else:
# we don't want to wait, so forget it
self.current_sync_operation = None
return False
finally:
self.condition.release()
if not wait:
return True
def do_release_write_lock(self):
self.condition.acquire()
try:
if self.current_sync_operation is not _threading.currentThread():
raise LockError("Synchronizer error - current thread doesn't "
"have the write lock")
# reset the current sync operation so
# another can get it
self.current_sync_operation = None
# tell everyone to get ready
self.condition.notifyAll()
finally:
# everyone go !!
self.condition.release()


@ -1,507 +0,0 @@
"""Beaker utilities"""
import hashlib
import socket
import binascii
from ._compat import PY2, string_type, unicode_text, NoneType, dictkeyslist, im_class, im_func, pickle, func_signature, \
default_im_func
try:
import threading as _threading
except ImportError:
import dummy_threading as _threading
from datetime import datetime, timedelta
import os
import re
import string
import types
import weakref
import warnings
import sys
import inspect
import json
import zlib
from beaker.converters import asbool
from beaker import exceptions
from threading import local as _tlocal
DEFAULT_CACHE_KEY_LENGTH = 250
__all__ = ["ThreadLocal", "WeakValuedRegistry", "SyncDict", "encoded_path",
"verify_directory",
"serialize", "deserialize"]
def function_named(fn, name):
"""Return a function with a given __name__.
Will assign to __name__ and return the original function if possible on
the Python implementation, otherwise a new function will be constructed.
"""
fn.__name__ = name
return fn
def skip_if(predicate, reason=None):
"""Skip a test if predicate is true."""
reason = reason or predicate.__name__
from nose import SkipTest
def decorate(fn):
fn_name = fn.__name__
def maybe(*args, **kw):
if predicate():
msg = "'%s' skipped: %s" % (
fn_name, reason)
raise SkipTest(msg)
else:
return fn(*args, **kw)
return function_named(maybe, fn_name)
return decorate
def assert_raises(except_cls, callable_, *args, **kw):
"""Assert the given exception is raised by the given function + arguments."""
try:
callable_(*args, **kw)
success = False
except except_cls:
success = True
# assert outside the block so it works for AssertionError too !
assert success, "Callable did not raise an exception"
def verify_directory(dir):
"""verifies and creates a directory, trying to
ignore collisions with other threads and processes."""
tries = 0
while not os.access(dir, os.F_OK):
try:
tries += 1
os.makedirs(dir)
except:
if tries > 5:
raise
def has_self_arg(func):
"""Return True if the given function has a 'self' argument."""
args = list(func_signature(func).parameters)
if args and args[0] in ('self', 'cls'):
return True
else:
return False
def warn(msg, stacklevel=3):
"""Issue a warning."""
if isinstance(msg, string_type):
warnings.warn(msg, exceptions.BeakerWarning, stacklevel=stacklevel)
else:
warnings.warn(msg, stacklevel=stacklevel)
def deprecated(message):
def wrapper(fn):
def deprecated_method(*args, **kargs):
warnings.warn(message, DeprecationWarning, 2)
return fn(*args, **kargs)
# TODO: use decorator ? functools.wrapper ?
deprecated_method.__name__ = fn.__name__
deprecated_method.__doc__ = "%s\n\n%s" % (message, fn.__doc__)
return deprecated_method
return wrapper
class ThreadLocal(object):
"""stores a value on a per-thread basis"""
__slots__ = '_tlocal'
def __init__(self):
self._tlocal = _tlocal()
def put(self, value):
self._tlocal.value = value
def has(self):
return hasattr(self._tlocal, 'value')
def get(self, default=None):
return getattr(self._tlocal, 'value', default)
def remove(self):
del self._tlocal.value
class SyncDict(object):
"""
An efficient/threadsafe singleton map algorithm, a.k.a.
"get a value based on this key, and create if not found or not
valid" paradigm:
exists && isvalid ? get : create
Designed to work with weakref dictionaries, where items are
expected to asynchronously disappear from the dictionary.
Use Python 2.3.3 or greater! A major bug was just fixed in Nov.
2003 that was driving me nuts with garbage collection/weakrefs in
this section.
"""
def __init__(self):
self.mutex = _threading.Lock()
self.dict = {}
def get(self, key, createfunc, *args, **kwargs):
try:
if key in self.dict:
return self.dict[key]
else:
return self.sync_get(key, createfunc, *args, **kwargs)
except KeyError:
return self.sync_get(key, createfunc, *args, **kwargs)
def sync_get(self, key, createfunc, *args, **kwargs):
self.mutex.acquire()
try:
try:
if key in self.dict:
return self.dict[key]
else:
return self._create(key, createfunc, *args, **kwargs)
except KeyError:
return self._create(key, createfunc, *args, **kwargs)
finally:
self.mutex.release()
def _create(self, key, createfunc, *args, **kwargs):
self[key] = obj = createfunc(*args, **kwargs)
return obj
def has_key(self, key):
return key in self.dict
def __contains__(self, key):
return self.dict.__contains__(key)
def __getitem__(self, key):
return self.dict.__getitem__(key)
def __setitem__(self, key, value):
self.dict.__setitem__(key, value)
def __delitem__(self, key):
return self.dict.__delitem__(key)
def clear(self):
self.dict.clear()
class WeakValuedRegistry(SyncDict):
def __init__(self):
self.mutex = _threading.RLock()
self.dict = weakref.WeakValueDictionary()
sha1 = None
def encoded_path(root, identifiers, extension=".enc", depth=3,
digest_filenames=True):
"""Generate a unique file-accessible path from the given list of
identifiers starting at the given root directory."""
ident = "_".join(identifiers)
global sha1
if sha1 is None:
from beaker.crypto import sha1
if digest_filenames:
if isinstance(ident, unicode_text):
ident = sha1(ident.encode('utf-8')).hexdigest()
else:
ident = sha1(ident).hexdigest()
ident = os.path.basename(ident)
tokens = []
for d in range(1, depth):
tokens.append(ident[0:d])
dir = os.path.join(root, *tokens)
verify_directory(dir)
return os.path.join(dir, ident + extension)
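For illustration: with the default depth=3 the (hashed) identifier is fanned out across two prefix directories, which are created as a side effect via verify_directory(). The root path below is illustrative:

from beaker.util import encoded_path

# e.g. if sha1('mynamespace') starts with 'ab34', this returns
# /tmp/cache/a/ab/ab34<...>.cache
path = encoded_path('/tmp/cache', ['mynamespace'], extension='.cache')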
def asint(obj):
if isinstance(obj, int):
return obj
elif isinstance(obj, string_type) and re.match(r'^\d+$', obj):
return int(obj)
else:
raise Exception("This is not a proper int")
def verify_options(opt, types, error):
if not isinstance(opt, types):
if not isinstance(types, tuple):
types = (types,)
coerced = False
for typ in types:
try:
if typ in (list, tuple):
opt = [x.strip() for x in opt.split(',')]
else:
if typ == bool:
typ = asbool
elif typ == int:
typ = asint
elif typ in (timedelta, datetime):
if not isinstance(opt, typ):
raise Exception("%s requires a timedelta type" % typ)
opt = typ(opt)
coerced = True
except:
pass
if coerced:
break
if not coerced:
raise Exception(error)
elif isinstance(opt, str) and not opt.strip():
raise Exception("Empty strings are invalid for: %s" % error)
return opt
def verify_rules(params, ruleset):
for key, types, message in ruleset:
if key in params:
params[key] = verify_options(params[key], types, message)
return params
def coerce_session_params(params):
rules = [
('data_dir', (str, NoneType), "data_dir must be a string referring to a directory."),
('lock_dir', (str, NoneType), "lock_dir must be a string referring to a directory."),
('type', (str, NoneType), "Session type must be a string."),
('cookie_expires', (bool, datetime, timedelta, int),
"Cookie expires was not a boolean, datetime, int, or timedelta instance."),
('cookie_domain', (str, NoneType), "Cookie domain must be a string."),
('cookie_path', (str, NoneType), "Cookie path must be a string."),
('id', (str,), "Session id must be a string."),
('key', (str,), "Session key must be a string."),
('secret', (str, NoneType), "Session secret must be a string."),
('validate_key', (str, NoneType), "Session validate_key must be a string."),
('encrypt_key', (str, NoneType), "Session encrypt_key must be a string."),
('encrypt_nonce_bits', (int, NoneType), "Session encrypt_nonce_bits must be a number"),
('secure', (bool, NoneType), "Session secure must be a boolean."),
('httponly', (bool, NoneType), "Session httponly must be a boolean."),
('timeout', (int, NoneType), "Session timeout must be an integer."),
('save_accessed_time', (bool, NoneType),
"Session save_accessed_time must be a boolean (defaults to true)."),
('auto', (bool, NoneType), "Session is created if accessed."),
('webtest_varname', (str, NoneType), "Session varname must be a string."),
('data_serializer', (str,), "data_serializer must be a string.")
]
opts = verify_rules(params, rules)
cookie_expires = opts.get('cookie_expires')
if cookie_expires and isinstance(cookie_expires, int) and \
not isinstance(cookie_expires, bool):
opts['cookie_expires'] = timedelta(seconds=cookie_expires)
if opts.get('timeout') is not None and not opts.get('save_accessed_time', True):
raise Exception("save_accessed_time must be true to use timeout")
return opts
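A sketch of the coercions this performs (values illustrative):

from datetime import timedelta
from beaker.util import coerce_session_params

opts = coerce_session_params({
    'type': 'file',
    'timeout': '300',         # digit string coerced via asint
    'cookie_expires': 3600,   # plain int converted to a timedelta below
    'secure': 'true',         # string coerced via asbool
})
assert opts['timeout'] == 300
assert opts['cookie_expires'] == timedelta(seconds=3600)
assert opts['secure'] is True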
def coerce_cache_params(params):
rules = [
('data_dir', (str, NoneType), "data_dir must be a string referring to a directory."),
('lock_dir', (str, NoneType), "lock_dir must be a string referring to a directory."),
('type', (str,), "Cache type must be a string."),
('enabled', (bool, NoneType), "enabled must be true/false if present."),
('expire', (int, NoneType),
"expire must be an integer representing how many seconds the cache is valid for"),
('regions', (list, tuple, NoneType),
"Regions must be a comma separated list of valid regions"),
('key_length', (int, NoneType),
"key_length must be an integer which indicates the longest a key can be before hashing"),
]
return verify_rules(params, rules)
def coerce_memcached_behaviors(behaviors):
rules = [
('cas', (bool, int), 'cas must be a boolean or an integer'),
('no_block', (bool, int), 'no_block must be a boolean or an integer'),
('receive_timeout', (int,), 'receive_timeout must be an integer'),
('send_timeout', (int,), 'send_timeout must be an integer'),
('ketama_hash', (str,),
'ketama_hash must be a string designating a valid hashing strategy option'),
('_poll_timeout', (int,), '_poll_timeout must be an integer'),
('auto_eject', (bool, int), 'auto_eject must be a boolean or an integer'),
('retry_timeout', (int,), 'retry_timeout must be an integer'),
('_sort_hosts', (bool, int), '_sort_hosts must be a boolean or an integer'),
('_io_msg_watermark', (int,), '_io_msg_watermark must be an integer'),
('ketama', (bool, int), 'ketama must be a boolean or an integer'),
('ketama_weighted', (bool, int), 'ketama_weighted must be a boolean or an integer'),
('_io_key_prefetch', (int, bool), '_io_key_prefetch must be a boolean or an integer'),
('_hash_with_prefix_key', (bool, int),
'_hash_with_prefix_key must be a boolean or an integer'),
('tcp_nodelay', (bool, int), 'tcp_nodelay must be a boolean or an integer'),
('failure_limit', (int,), 'failure_limit must be an integer'),
('buffer_requests', (bool, int), 'buffer_requests must be a boolean or an integer'),
('_socket_send_size', (int,), '_socket_send_size must be an integer'),
('num_replicas', (int,), 'num_replicas must be an integer'),
('remove_failed', (int,), 'remove_failed must be an integer'),
('_noreply', (bool, int), '_noreply must be a boolean or an integer'),
('_io_bytes_watermark', (int,), '_io_bytes_watermark must be an integer'),
('_socket_recv_size', (int,), '_socket_recv_size must be an integer'),
('distribution', (str,),
'distribution must be a string designating a valid distribution option'),
('connect_timeout', (int,), 'connect_timeout must be an integer'),
('hash', (str,), 'hash must be a string designating a valid hashing option'),
('verify_keys', (bool, int), 'verify_keys must be a boolean or an integer'),
('dead_timeout', (int,), 'dead_timeout must be an integer')
]
return verify_rules(behaviors, rules)
def parse_cache_config_options(config, include_defaults=True):
"""Parse configuration options and validate for use with the
CacheManager"""
# Load default cache options
if include_defaults:
options = dict(type='memory', data_dir=None, expire=None,
log_file=None)
else:
options = {}
for key, val in config.items():
if key.startswith('beaker.cache.'):
options[key[13:]] = val
if key.startswith('cache.'):
options[key[6:]] = val
coerce_cache_params(options)
# Set cache to enabled if not turned off
if 'enabled' not in options and include_defaults:
options['enabled'] = True
# Configure region dict if regions are available
regions = options.pop('regions', None)
if regions:
region_configs = {}
for region in regions:
if not region: # ensure region name is valid
continue
# Setup the default cache options
region_options = dict(data_dir=options.get('data_dir'),
lock_dir=options.get('lock_dir'),
type=options.get('type'),
enabled=options['enabled'],
expire=options.get('expire'),
key_length=options.get('key_length', DEFAULT_CACHE_KEY_LENGTH))
region_prefix = '%s.' % region
region_len = len(region_prefix)
for key in dictkeyslist(options):
if key.startswith(region_prefix):
region_options[key[region_len:]] = options.pop(key)
coerce_cache_params(region_options)
region_configs[region] = region_options
options['cache_regions'] = region_configs
return options
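A sketch of the prefix handling and region expansion (keys and paths illustrative):

from beaker.util import parse_cache_config_options

options = parse_cache_config_options({
    'cache.type': 'file',
    'cache.data_dir': '/tmp/cache/data',
    'cache.lock_dir': '/tmp/cache/lock',
    'cache.regions': 'short, long',
    'cache.short.expire': '60',
    'cache.long.expire': '3600',
})
# 'cache.'-prefixed keys are unprefixed, regions become per-region dicts,
# and numeric strings are coerced.
assert options['cache_regions']['short']['expire'] == 60
assert options['cache_regions']['long']['type'] == 'file'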
def parse_memcached_behaviors(config):
"""Parse behavior options and validate for use with pylibmc
client/PylibMCNamespaceManager, or potentially other memcached
NamespaceManagers that support behaviors"""
behaviors = {}
for key, val in config.items():
if key.startswith('behavior.'):
behaviors[key[9:]] = val
coerce_memcached_behaviors(behaviors)
return behaviors
def func_namespace(func):
"""Generates a unique namespace for a function"""
kls = None
if hasattr(func, 'im_func') or hasattr(func, '__func__'):
kls = im_class(func)
func = im_func(func)
if kls:
return '%s.%s' % (kls.__module__, kls.__name__)
else:
return '%s|%s' % (inspect.getsourcefile(func), func.__name__)
class PickleSerializer(object):
def loads(self, data_string):
return pickle.loads(data_string)
def dumps(self, data):
return pickle.dumps(data, 2)
class JsonSerializer(object):
def loads(self, data_string):
return json.loads(zlib.decompress(data_string).decode('utf-8'))
def dumps(self, data):
return zlib.compress(json.dumps(data).encode('utf-8'))
def serialize(data, method):
if method == 'json':
serializer = JsonSerializer()
else:
serializer = PickleSerializer()
return serializer.dumps(data)
def deserialize(data_string, method):
if method == 'json':
serializer = JsonSerializer()
else:
serializer = PickleSerializer()
return serializer.loads(data_string)
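Both helpers round-trip, with 'json' additionally zlib-compressing the UTF-8 JSON:

from beaker.util import serialize, deserialize

payload = serialize({'user': 'alice'}, 'json')
assert deserialize(payload, 'json') == {'user': 'alice'}
payload = serialize({'user': 'alice'}, 'pickle')   # any non-'json' method
assert deserialize(payload, 'pickle') == {'user': 'alice'}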
def machine_identifier():
machine_hash = hashlib.md5()
if not PY2:
machine_hash.update(socket.gethostname().encode())
else:
machine_hash.update(socket.gethostname())
return binascii.hexlify(machine_hash.digest()[0:3]).decode('ascii')
def safe_write(filepath, contents):
if os.name == 'posix':
tempname = '%s.temp' % (filepath)
fh = open(tempname, 'wb')
fh.write(contents)
fh.close()
os.rename(tempname, filepath)
else:
fh = open(filepath, 'wb')
fh.write(contents)
fh.close()


@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
@ -26,8 +26,9 @@
#==============================================================================
"""
Efficient, Pythonic bidirectional map implementation and related functionality.
"""The bidirectional mapping library for Python.
bidict by example:
.. code-block:: python
@ -44,66 +45,45 @@ https://bidict.readthedocs.io for the most up-to-date documentation
if you are reading this elsewhere.
.. :copyright: (c) 2019 Joshua Bronson.
.. :copyright: (c) 2009-2021 Joshua Bronson.
.. :license: MPLv2. See LICENSE for details.
"""
# This __init__.py only collects functionality implemented in the rest of the
# source and exports it under the `bidict` module namespace (via `__all__`).
# Use private aliases to not re-export these publicly (for Sphinx automodule with imported-members).
from sys import version_info as _version_info
from ._abc import BidirectionalMapping
if _version_info < (3, 6): # pragma: no cover
raise ImportError('Python 3.6+ is required.')
from ._abc import BidirectionalMapping, MutableBidirectionalMapping
from ._base import BidictBase
from ._mut import MutableBidict
from ._bidict import bidict
from ._dup import DuplicationPolicy, IGNORE, OVERWRITE, RAISE
from ._exc import (
BidictException, DuplicationError,
KeyDuplicationError, ValueDuplicationError, KeyAndValueDuplicationError)
from ._util import inverted
from ._frozenbidict import frozenbidict
from ._frozenordered import FrozenOrderedBidict
from ._named import namedbidict
from ._orderedbase import OrderedBidictBase
from ._orderedbidict import OrderedBidict
from ._dup import ON_DUP_DEFAULT, ON_DUP_RAISE, ON_DUP_DROP_OLD, RAISE, DROP_OLD, DROP_NEW, OnDup, OnDupAction
from ._exc import BidictException, DuplicationError, KeyDuplicationError, ValueDuplicationError, KeyAndValueDuplicationError
from ._iter import inverted
from .metadata import (
__author__, __maintainer__, __copyright__, __email__, __credits__, __url__,
__license__, __status__, __description__, __keywords__, __version__, __version_info__)
__all__ = (
'__author__',
'__maintainer__',
'__copyright__',
'__email__',
'__credits__',
'__license__',
'__status__',
'__description__',
'__keywords__',
'__url__',
'__version__',
'__version_info__',
'BidirectionalMapping',
'BidictException',
'DuplicationPolicy',
'IGNORE',
'OVERWRITE',
'RAISE',
'DuplicationError',
'KeyDuplicationError',
'ValueDuplicationError',
'KeyAndValueDuplicationError',
'BidictBase',
'MutableBidict',
'frozenbidict',
'bidict',
'namedbidict',
'FrozenOrderedBidict',
'OrderedBidictBase',
'OrderedBidict',
'inverted',
__license__, __status__, __description__, __keywords__, __version__,
)
# Set __module__ of re-exported classes to the 'bidict' top-level module name
# so that private/internal submodules are not exposed to users e.g. in repr strings.
_locals = tuple(locals().items())
for _name, _obj in _locals: # pragma: no cover
if not getattr(_obj, '__module__', '').startswith('bidict.'):
continue
try:
_obj.__module__ = 'bidict'
except AttributeError: # raised when __module__ is read-only (as in OnDup)
pass
# * Code review nav *
#==============================================================================


@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
@ -26,12 +26,15 @@
#==============================================================================
"""Provides the :class:`BidirectionalMapping` abstract base class."""
"""Provide the :class:`BidirectionalMapping` abstract base class."""
from .compat import Mapping, abstractproperty, iteritems
import typing as _t
from abc import abstractmethod
from ._typing import KT, VT
class BidirectionalMapping(Mapping): # pylint: disable=abstract-method,no-init
class BidirectionalMapping(_t.Mapping[KT, VT]):
"""Abstract base class (ABC) for bidirectional mapping types.
Extends :class:`collections.abc.Mapping` primarily by adding the
@ -43,8 +46,9 @@ class BidirectionalMapping(Mapping): # pylint: disable=abstract-method,no-init
__slots__ = ()
@abstractproperty
def inverse(self):
@property
@abstractmethod
def inverse(self) -> 'BidirectionalMapping[VT, KT]':
"""The inverse of this bidirectional mapping instance.
*See also* :attr:`bidict.BidictBase.inverse`, :attr:`bidict.BidictBase.inv`
@ -58,7 +62,7 @@ class BidirectionalMapping(Mapping): # pylint: disable=abstract-method,no-init
# clear there's no reason to call this implementation (e.g. via super() after overriding).
raise NotImplementedError
def __inverted__(self):
def __inverted__(self) -> _t.Iterator[_t.Tuple[VT, KT]]:
"""Get an iterator over the items in :attr:`inverse`.
This is functionally equivalent to iterating over the items in the
@ -72,7 +76,27 @@ class BidirectionalMapping(Mapping): # pylint: disable=abstract-method,no-init
*See also* :func:`bidict.inverted`
"""
return iteritems(self.inverse)
return iter(self.inverse.items())
def values(self) -> _t.KeysView[VT]: # type: ignore [override] # https://github.com/python/typeshed/issues/4435
"""A set-like object providing a view on the contained values.
Override the implementation inherited from
:class:`~collections.abc.Mapping`.
Because the values of a :class:`~bidict.BidirectionalMapping`
are the keys of its inverse,
this returns a :class:`~collections.abc.KeysView`
rather than a :class:`~collections.abc.ValuesView`,
which has the advantages of constant-time containment checks
and supporting set operations.
"""
return self.inverse.keys() # type: ignore [return-value]
class MutableBidirectionalMapping(BidirectionalMapping[KT, VT], _t.MutableMapping[KT, VT]):
"""Abstract base class (ABC) for mutable bidirectional mapping types."""
__slots__ = ()
# * Code review nav *
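The practical upshot of the new values() override, as a sketch against the public bidict API:

from bidict import bidict

element_by_symbol = bidict(H='hydrogen', He='helium')
# values() is the inverse's keys() view: constant-time membership
# tests and full set semantics.
assert 'helium' in element_by_symbol.values()
assert element_by_symbol.values() & {'helium', 'neon'} == {'helium'}
assert element_by_symbol.inverse['helium'] == 'He'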


@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
@ -22,139 +22,118 @@
# * Code review nav *
#==============================================================================
# ← Prev: _abc.py Current: _base.py Next: _delegating_mixins.py →
# ← Prev: _abc.py Current: _base.py Next: _frozenbidict.py →
#==============================================================================
"""Provides :class:`BidictBase`."""
"""Provide :class:`BidictBase`."""
import typing as _t
from collections import namedtuple
from copy import copy
from weakref import ref
from ._abc import BidirectionalMapping
from ._dup import RAISE, OVERWRITE, IGNORE, _OnDup
from ._exc import (
DuplicationError, KeyDuplicationError, ValueDuplicationError, KeyAndValueDuplicationError)
from ._miss import _MISS
from ._noop import _NOOP
from ._util import _iteritems_args_kw
from .compat import PY2, KeysView, ItemsView, Mapping, iteritems
from ._dup import ON_DUP_DEFAULT, RAISE, DROP_OLD, DROP_NEW, OnDup
from ._exc import DuplicationError, KeyDuplicationError, ValueDuplicationError, KeyAndValueDuplicationError
from ._iter import _iteritems_args_kw
from ._typing import _NONE, KT, VT, OKT, OVT, IterItems, MapOrIterItems
_DedupResult = namedtuple('_DedupResult', 'isdupkey isdupval invbyval fwdbykey')
_WriteResult = namedtuple('_WriteResult', 'key val oldkey oldval')
_NODUP = _DedupResult(False, False, _MISS, _MISS)
_DedupResult = namedtuple('_DedupResult', 'isdupkey isdupval invbyval fwdbykey')
_NODUP = _DedupResult(False, False, _NONE, _NONE)
BT = _t.TypeVar('BT', bound='BidictBase') # typevar for BidictBase.copy
class BidictBase(BidirectionalMapping):
class BidictBase(BidirectionalMapping[KT, VT]):
"""Base class implementing :class:`BidirectionalMapping`."""
__slots__ = ('_fwdm', '_invm', '_inv', '_invweak', '_hash') + (() if PY2 else ('__weakref__',))
__slots__ = ['_fwdm', '_invm', '_inv', '_invweak', '__weakref__']
#: The default :class:`DuplicationPolicy`
#: (in effect during e.g. :meth:`~bidict.bidict.__init__` calls)
#: The default :class:`~bidict.OnDup`
#: that governs behavior when a provided item
#: duplicates only the key of another item.
#:
#: Defaults to :attr:`~bidict.OVERWRITE`
#: to match :class:`dict`'s behavior.
#: duplicates the key or value of other item(s).
#:
#: *See also* :ref:`basic-usage:Values Must Be Unique`, :doc:`extending`
on_dup_key = OVERWRITE
on_dup = ON_DUP_DEFAULT
#: The default :class:`DuplicationPolicy`
#: (in effect during e.g. :meth:`~bidict.bidict.__init__` calls)
#: that governs behavior when a provided item
#: duplicates only the value of another item.
#:
#: Defaults to :attr:`~bidict.RAISE`
#: to prevent unintended overwrite of another item.
#:
#: *See also* :ref:`basic-usage:Values Must Be Unique`, :doc:`extending`
on_dup_val = RAISE
#: The default :class:`DuplicationPolicy`
#: (in effect during e.g. :meth:`~bidict.bidict.__init__` calls)
#: that governs behavior when a provided item
#: duplicates the key of another item and the value of a third item.
#:
#: Defaults to ``None``, which causes the *on_dup_kv* policy to match
#: whatever *on_dup_val* policy is in effect.
#:
#: *See also* :ref:`basic-usage:Values Must Be Unique`, :doc:`extending`
on_dup_kv = None
_fwdm_cls = dict
_invm_cls = dict
_fwdm_cls: _t.Type[_t.MutableMapping[KT, VT]] = dict #: class of the backing forward mapping
_invm_cls: _t.Type[_t.MutableMapping[VT, KT]] = dict #: class of the backing inverse mapping
#: The object used by :meth:`__repr__` for printing the contained items.
_repr_delegate = dict
_repr_delegate: _t.Callable = dict
def __init__(self, *args, **kw): # pylint: disable=super-init-not-called
_inv: 'BidictBase[VT, KT]'
_inv_cls: '_t.Type[BidictBase[VT, KT]]'
def __init_subclass__(cls, **kw):
super().__init_subclass__(**kw)
# Compute and set _inv_cls, the inverse of this bidict class.
if '_inv_cls' in cls.__dict__:
return
if cls._fwdm_cls is cls._invm_cls:
cls._inv_cls = cls
return
inv_cls = type(cls.__name__ + 'Inv', cls.__bases__, {
**cls.__dict__,
'_inv_cls': cls,
'_fwdm_cls': cls._invm_cls,
'_invm_cls': cls._fwdm_cls,
})
cls._inv_cls = inv_cls
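A sketch of the class-level wiring this performs; the MyBidict/MyBidictInv names are illustrative, and BidictBase is assumed importable from the vendored package:
from collections import OrderedDict
from bidict import BidictBase
class MyBidict(BidictBase):
    __slots__ = ()
    _fwdm_cls = dict
    _invm_cls = OrderedDict
# Since the backing mapping types differ, a dedicated inverse class is generated:
assert MyBidict._inv_cls.__name__ == 'MyBidictInv'
assert MyBidict._inv_cls._fwdm_cls is OrderedDict   # backing classes are swapped
assert MyBidict._inv_cls._inv_cls is MyBidict       # and the relationship is symmetric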
@_t.overload
def __init__(self, __arg: _t.Mapping[KT, VT], **kw: VT) -> None: ...
@_t.overload
def __init__(self, __arg: IterItems[KT, VT], **kw: VT) -> None: ...
@_t.overload
def __init__(self, **kw: VT) -> None: ...
def __init__(self, *args: MapOrIterItems[KT, VT], **kw: VT) -> None:
"""Make a new bidirectional dictionary.
The signature is the same as that of regular dictionaries.
The signature behaves like that of :class:`dict`.
Items passed in are added in the order they are passed,
respecting the current duplication policies in the process.
*See also* :attr:`on_dup_key`, :attr:`on_dup_val`, :attr:`on_dup_kv`
respecting the :attr:`on_dup` class attribute in the process.
"""
#: The backing :class:`~collections.abc.Mapping`
#: storing the forward mapping data (*key* → *value*).
self._fwdm = self._fwdm_cls()
self._fwdm: _t.MutableMapping[KT, VT] = self._fwdm_cls()
#: The backing :class:`~collections.abc.Mapping`
#: storing the inverse mapping data (*value* → *key*).
self._invm = self._invm_cls()
self._init_inv() # lgtm [py/init-calls-subclass]
self._invm: _t.MutableMapping[VT, KT] = self._invm_cls()
self._init_inv()
if args or kw:
self._update(True, None, *args, **kw)
self._update(True, self.on_dup, *args, **kw)
def _init_inv(self):
# Compute the type for this bidict's inverse bidict (will be different from this
# bidict's type if _fwdm_cls and _invm_cls are different).
inv_cls = self._inv_cls()
def _init_inv(self) -> None:
# Create the inverse bidict instance via __new__, bypassing its __init__ so that its
# _fwdm and _invm can be assigned to this bidict's _invm and _fwdm. Store it in self._inv,
# which holds a strong reference to a bidict's inverse, if one is available.
self._inv = inv = inv_cls.__new__(inv_cls)
inv._fwdm = self._invm # pylint: disable=protected-access
inv._invm = self._fwdm # pylint: disable=protected-access
self._inv = inv = self._inv_cls.__new__(self._inv_cls)
inv._fwdm = self._invm
inv._invm = self._fwdm
# Only give the inverse a weak reference to this bidict to avoid creating a reference cycle,
# stored in the _invweak attribute. See also the docs in
# :ref:`addendum:Bidict Avoids Reference Cycles`
inv._inv = None # pylint: disable=protected-access
inv._invweak = ref(self) # pylint: disable=protected-access
inv._inv = None
inv._invweak = ref(self)
# Since this bidict has a strong reference to its inverse already, set its _invweak to None.
self._invweak = None
@classmethod
def _inv_cls(cls):
"""The inverse of this bidict type, i.e. one with *_fwdm_cls* and *_invm_cls* swapped."""
if cls._fwdm_cls is cls._invm_cls:
return cls
if not getattr(cls, '_inv_cls_', None):
class _Inv(cls):
_fwdm_cls = cls._invm_cls
_invm_cls = cls._fwdm_cls
_inv_cls_ = cls
_Inv.__name__ = cls.__name__ + 'Inv'
cls._inv_cls_ = _Inv
return cls._inv_cls_
@property
def _isinv(self):
def _isinv(self) -> bool:
return self._inv is None
@property
def inverse(self):
"""The inverse of this bidict.
*See also* :attr:`inv`
"""
def inverse(self) -> 'BidictBase[VT, KT]':
"""The inverse of this bidict."""
# Resolve and return a strong reference to the inverse bidict.
# One may be stored in self._inv already.
if self._inv is not None:
return self._inv
# Otherwise a weakref is stored in self._invweak. Try to get a strong ref from it.
assert self._invweak is not None
inv = self._invweak()
if inv is not None:
return inv
@ -162,12 +141,10 @@ class BidictBase(BidirectionalMapping):
self._init_inv() # Now this bidict will retain a strong ref to its inverse.
return self._inv
@property
def inv(self):
"""Alias for :attr:`inverse`."""
return self.inverse
#: Alias for :attr:`inverse`.
inv = inverse
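The resulting one-way strong reference can be observed directly (a sketch that pokes at private attributes, so for illustration only):
from bidict import bidict
b = bidict(one=1)
inv = b.inverse
assert b._inv is inv      # the forward bidict keeps a strong ref to its inverse
assert inv._inv is None   # the inverse keeps only a weakref back...
assert inv.inverse is b   # ...which the inverse property resolves on access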
def __getstate__(self):
def __getstate__(self) -> dict:
"""Needed to enable pickling due to use of :attr:`__slots__` and weakrefs.
*See also* :meth:`object.__getstate__`
@ -183,27 +160,27 @@ class BidictBase(BidirectionalMapping):
state.pop('__weakref__', None) # Not added back in __setstate__. Python manages this one.
return state
def __setstate__(self, state):
def __setstate__(self, state: dict) -> None:
"""Implemented because use of :attr:`__slots__` would prevent unpickling otherwise.
*See also* :meth:`object.__setstate__`
"""
for slot, value in iteritems(state):
for slot, value in state.items():
setattr(self, slot, value)
self._init_inv()
def __repr__(self):
def __repr__(self) -> str:
"""See :func:`repr`."""
clsname = self.__class__.__name__
if not self:
return '%s()' % clsname
return '%s(%r)' % (clsname, self._repr_delegate(iteritems(self)))
return f'{clsname}()'
return f'{clsname}({self._repr_delegate(self.items())})'
# The inherited Mapping.__eq__ implementation would work, but it's implemented in terms of an
# inefficient ``dict(self.items()) == dict(other.items())`` comparison, so override it with a
# more efficient implementation.
def __eq__(self, other):
u"""*x.__eq__(other)  x == other*
def __eq__(self, other: object) -> bool:
"""*x.__eq__(other)  x == other*
Equivalent to *dict(x.items()) == dict(other.items())*
but more efficient.
@ -216,101 +193,98 @@ class BidictBase(BidirectionalMapping):
*See also* :meth:`bidict.FrozenOrderedBidict.equals_order_sensitive`
"""
if not isinstance(other, Mapping) or len(self) != len(other):
if not isinstance(other, _t.Mapping) or len(self) != len(other):
return False
selfget = self.get
return all(selfget(k, _MISS) == v for (k, v) in iteritems(other))
return all(selfget(k, _NONE) == v for (k, v) in other.items()) # type: ignore [arg-type]
def equals_order_sensitive(self, other: object) -> bool:
"""Order-sensitive equality check.
*See also* :ref:`eq-order-insensitive`
"""
# Same short-circuit as in __eq__ above. Factoring out not worth function call overhead.
if not isinstance(other, _t.Mapping) or len(self) != len(other):
return False
return all(i == j for (i, j) in zip(self.items(), other.items()))
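For instance (assuming the vendored package exports OrderedBidict):
from bidict import OrderedBidict
a = OrderedBidict([(1, 'one'), (2, 'two')])
b = OrderedBidict([(2, 'two'), (1, 'one')])
assert a == b                            # __eq__ is order-insensitive, like dict's
assert not a.equals_order_sensitive(b)   # this check also compares item order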
# The following methods are mutating and so are not public. But they are implemented in this
# non-mutable base class (rather than the mutable `bidict` subclass) because they are used here
# during initialization (starting with the `_update` method). (Why is this? Because `__init__`
# and `update` share a lot of the same behavior (inserting the provided items while respecting
# the active duplication policies), so it makes sense for them to share implementation too.)
def _pop(self, key):
# `on_dup`), so it makes sense for them to share implementation too.)
def _pop(self, key: KT) -> VT:
val = self._fwdm.pop(key)
del self._invm[val]
return val
def _put(self, key, val, on_dup):
def _put(self, key: KT, val: VT, on_dup: OnDup) -> None:
dedup_result = self._dedup_item(key, val, on_dup)
if dedup_result is not _NOOP:
if dedup_result is not None:
self._write_item(key, val, dedup_result)
def _dedup_item(self, key, val, on_dup):
"""
Check *key* and *val* for any duplication in self.
def _dedup_item(self, key: KT, val: VT, on_dup: OnDup) -> _t.Optional[_DedupResult]:
"""Check *key* and *val* for any duplication in self.
Handle any duplication as per the duplication policies given in *on_dup*.
Handle any duplication as per the passed in *on_dup*.
(key, val) already present is construed as a no-op, not a duplication.
If duplication is found and the corresponding duplication policy is
If duplication is found and the corresponding :class:`~bidict.OnDupAction` is
:attr:`~bidict.DROP_NEW`, return None.
If duplication is found and the corresponding :class:`~bidict.OnDupAction` is
:attr:`~bidict.RAISE`, raise the appropriate error.
If duplication is found and the corresponding duplication policy is
:attr:`~bidict.IGNORE`, return *None*.
If duplication is found and the corresponding duplication policy is
:attr:`~bidict.OVERWRITE`,
If duplication is found and the corresponding :class:`~bidict.OnDupAction` is
:attr:`~bidict.DROP_OLD`,
or if no duplication is found,
return the _DedupResult *(isdupkey, isdupval, oldkey, oldval)*.
return the :class:`_DedupResult` *(isdupkey, isdupval, oldkey, oldval)*.
"""
fwdm = self._fwdm
invm = self._invm
oldval = fwdm.get(key, _MISS)
oldkey = invm.get(val, _MISS)
isdupkey = oldval is not _MISS
isdupval = oldkey is not _MISS
oldval: OVT = fwdm.get(key, _NONE)
oldkey: OKT = invm.get(val, _NONE)
isdupkey = oldval is not _NONE
isdupval = oldkey is not _NONE
dedup_result = _DedupResult(isdupkey, isdupval, oldkey, oldval)
if isdupkey and isdupval:
if self._isdupitem(key, val, dedup_result):
if self._already_have(key, val, oldkey, oldval):
# (key, val) duplicates an existing item -> no-op.
return _NOOP
return None
# key and val each duplicate a different existing item.
if on_dup.kv is RAISE:
raise KeyAndValueDuplicationError(key, val)
elif on_dup.kv is IGNORE:
return _NOOP
assert on_dup.kv is OVERWRITE, 'invalid on_dup_kv: %r' % on_dup.kv
if on_dup.kv is DROP_NEW:
return None
assert on_dup.kv is DROP_OLD
# Fall through to the return statement on the last line.
elif isdupkey:
if on_dup.key is RAISE:
raise KeyDuplicationError(key)
elif on_dup.key is IGNORE:
return _NOOP
assert on_dup.key is OVERWRITE, 'invalid on_dup.key: %r' % on_dup.key
if on_dup.key is DROP_NEW:
return None
assert on_dup.key is DROP_OLD
# Fall through to the return statement on the last line.
elif isdupval:
if on_dup.val is RAISE:
raise ValueDuplicationError(val)
elif on_dup.val is IGNORE:
return _NOOP
assert on_dup.val is OVERWRITE, 'invalid on_dup.val: %r' % on_dup.val
if on_dup.val is DROP_NEW:
return None
assert on_dup.val is DROP_OLD
# Fall through to the return statement on the last line.
# else neither isdupkey nor isdupval.
return dedup_result
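How the three actions play out through the public put() API (a sketch; the OnDup and exception names below are exported by bidict 0.21.x):
from bidict import bidict, OnDup, RAISE, DROP_OLD, DROP_NEW, ValueDuplicationError
b = bidict(one=1)
b.put('one', 1)                  # exact item already present -> no-op
try:
    b.put('uno', 1)              # value duplication under the default ON_DUP_RAISE
except ValueDuplicationError:
    pass
b.put('uno', 1, on_dup=OnDup(key=RAISE, val=DROP_NEW))
assert b == {'one': 1}           # DROP_NEW kept the existing item
b.put('uno', 1, on_dup=OnDup(key=RAISE, val=DROP_OLD))
assert b == {'uno': 1}           # DROP_OLD replaced it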
@staticmethod
def _isdupitem(key, val, dedup_result):
isdupkey, isdupval, oldkey, oldval = dedup_result
isdupitem = oldkey == key
assert isdupitem == (oldval == val), '%r %r %r' % (key, val, dedup_result)
if isdupitem:
assert isdupkey
assert isdupval
return isdupitem
def _already_have(key: KT, val: VT, oldkey: OKT, oldval: OVT) -> bool:
# Overridden by _orderedbase.OrderedBidictBase.
isdup = oldkey == key
assert isdup == (oldval == val), f'{key} {val} {oldkey} {oldval}'
return isdup
@classmethod
def _get_on_dup(cls, on_dup=None):
if on_dup is None:
on_dup = _OnDup(cls.on_dup_key, cls.on_dup_val, cls.on_dup_kv)
elif not isinstance(on_dup, _OnDup):
on_dup = _OnDup(*on_dup)
if on_dup.kv is None:
on_dup = on_dup._replace(kv=on_dup.val)
return on_dup
def _write_item(self, key, val, dedup_result):
def _write_item(self, key: KT, val: VT, dedup_result: _DedupResult) -> _WriteResult:
# Overridden by _orderedbase.OrderedBidictBase.
isdupkey, isdupval, oldkey, oldval = dedup_result
fwdm = self._fwdm
invm = self._invm
@ -322,35 +296,34 @@ class BidictBase(BidirectionalMapping):
del fwdm[oldkey]
return _WriteResult(key, val, oldkey, oldval)
def _update(self, init, on_dup, *args, **kw):
def _update(self, init: bool, on_dup: OnDup, *args: MapOrIterItems[KT, VT], **kw: VT) -> None:
# args[0] may be a generator that yields many items, so process input in a single pass.
if not args and not kw:
return
can_skip_dup_check = not self and not kw and isinstance(args[0], BidirectionalMapping)
if can_skip_dup_check:
self._update_no_dup_check(args[0])
self._update_no_dup_check(args[0]) # type: ignore [arg-type]
return
on_dup = self._get_on_dup(on_dup)
can_skip_rollback = init or RAISE not in on_dup
if can_skip_rollback:
self._update_no_rollback(on_dup, *args, **kw)
else:
self._update_with_rollback(on_dup, *args, **kw)
def _update_no_dup_check(self, other, _nodup=_NODUP):
def _update_no_dup_check(self, other: BidirectionalMapping[KT, VT]) -> None:
write_item = self._write_item
for (key, val) in iteritems(other):
write_item(key, val, _nodup)
for (key, val) in other.items():
write_item(key, val, _NODUP)
def _update_no_rollback(self, on_dup, *args, **kw):
def _update_no_rollback(self, on_dup: OnDup, *args: MapOrIterItems[KT, VT], **kw: VT) -> None:
put = self._put
for (key, val) in _iteritems_args_kw(*args, **kw):
put(key, val, on_dup)
def _update_with_rollback(self, on_dup, *args, **kw):
def _update_with_rollback(self, on_dup: OnDup, *args: MapOrIterItems[KT, VT], **kw: VT) -> None:
"""Update, rolling back on failure."""
writelog = []
appendlog = writelog.append
writes: _t.List[_t.Tuple[_DedupResult, _WriteResult]] = []
append_write = writes.append
dedup_item = self._dedup_item
write_item = self._write_item
for (key, val) in _iteritems_args_kw(*args, **kw):
@ -358,14 +331,14 @@ class BidictBase(BidirectionalMapping):
dedup_result = dedup_item(key, val, on_dup)
except DuplicationError:
undo_write = self._undo_write
for dedup_result, write_result in reversed(writelog):
for dedup_result, write_result in reversed(writes):
undo_write(dedup_result, write_result)
raise
if dedup_result is not _NOOP:
if dedup_result is not None:
write_result = write_item(key, val, dedup_result)
appendlog((dedup_result, write_result))
append_write((dedup_result, write_result))
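The rollback path is what makes a failing bulk update atomic, e.g. (a sketch using the default on_dup, under which value duplication raises):
from bidict import bidict, KeyAndValueDuplicationError
b = bidict(one=1, two=2)
try:
    b.update([('three', 3), ('one', 2)])   # 2nd item dups a key and another item's value
except KeyAndValueDuplicationError:
    pass
assert b == {'one': 1, 'two': 2}           # ('three', 3) was rolled back as well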
def _undo_write(self, dedup_result, write_result):
def _undo_write(self, dedup_result: _DedupResult, write_result: _WriteResult) -> None:
isdupkey, isdupval, _, _ = dedup_result
key, val, oldkey, oldval = write_result
if not isdupkey and not isdupval:
@ -384,79 +357,48 @@ class BidictBase(BidirectionalMapping):
if not isdupkey:
del fwdm[key]
def copy(self):
def copy(self: BT) -> BT:
"""A shallow copy."""
# Could just ``return self.__class__(self)`` here instead, but the below is faster. It uses
# __new__ to create a copy instance while bypassing its __init__, which would result
# in copying this bidict's items into the copy instance one at a time. Instead, make whole
# copies of each of the backing mappings, and make them the backing mappings of the copy,
# avoiding copying items one at a time.
copy = self.__class__.__new__(self.__class__)
copy._fwdm = self._fwdm.copy() # pylint: disable=protected-access
copy._invm = self._invm.copy() # pylint: disable=protected-access
copy._init_inv() # pylint: disable=protected-access
return copy
cp: BT = self.__class__.__new__(self.__class__)
cp._fwdm = copy(self._fwdm)
cp._invm = copy(self._invm)
cp._init_inv()
return cp
def __copy__(self):
"""Used for the copy protocol.
#: Used for the copy protocol.
#: *See also* the :mod:`copy` module
__copy__ = copy
*See also* the :mod:`copy` module
"""
return self.copy()
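Since __copy__ is aliased to copy(), the copy module picks up the fast path too (a minimal sketch):
from copy import copy as _copy
from bidict import bidict
b = bidict(one=1)
c = _copy(b)                   # dispatches to BidictBase.copy via __copy__
assert c == b and c is not b
c['two'] = 2                   # the copy has independent backing mappings
assert 'two' not in b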
def __len__(self):
def __len__(self) -> int:
"""The number of contained items."""
return len(self._fwdm)
def __iter__(self): # lgtm [py/inheritance/incorrect-overridden-signature]
"""Iterator over the contained items."""
# No default implementation for __iter__ inherited from Mapping ->
# always delegate to _fwdm.
def __iter__(self) -> _t.Iterator[KT]:
"""Iterator over the contained keys."""
return iter(self._fwdm)
def __getitem__(self, key):
u"""*x.__getitem__(key)  x[key]*"""
def __getitem__(self, key: KT) -> VT:
"""*x.__getitem__(key)  x[key]*"""
return self._fwdm[key]
def values(self):
"""A set-like object providing a view on the contained values.
# On Python 3.8+, dicts are reversible, so even non-Ordered bidicts can provide an efficient
# __reversed__ implementation. (On Python < 3.8, they cannot.) Once support is dropped for
# Python < 3.8, can remove the following if statement to provide __reversed__ unconditionally.
if hasattr(_fwdm_cls, '__reversed__'):
def __reversed__(self) -> _t.Iterator[KT]:
"""Iterator over the contained keys in reverse order."""
return reversed(self._fwdm) # type: ignore [no-any-return,call-overload]
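A quick sketch of the effect (requires Python 3.8+, where dict itself is reversible):
from bidict import bidict
b = bidict(one=1, two=2)
assert list(reversed(b)) == ['two', 'one']   # delegates to reversed(self._fwdm)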
Note that because the values of a :class:`~bidict.BidirectionalMapping`
are the keys of its inverse,
this returns a :class:`~collections.abc.KeysView`
rather than a :class:`~collections.abc.ValuesView`,
which has the advantages of constant-time containment checks
and supporting set operations.
"""
return self.inverse.keys()
if PY2:
# For iterkeys and iteritems, inheriting from Mapping already provides
# the best default implementations so no need to define here.
def itervalues(self):
"""An iterator over the contained values."""
return self.inverse.iterkeys()
def viewkeys(self): # noqa: D102; pylint: disable=missing-docstring
return KeysView(self)
def viewvalues(self): # noqa: D102; pylint: disable=missing-docstring
return self.inverse.viewkeys()
viewvalues.__doc__ = values.__doc__
values.__doc__ = 'A list of the contained values.'
def viewitems(self): # noqa: D102; pylint: disable=missing-docstring
return ItemsView(self)
# __ne__ added automatically in Python 3 when you implement __eq__, but not in Python 2.
def __ne__(self, other): # noqa: N802
u"""*x.__ne__(other)  x != other*"""
return not self == other # Implement __ne__ in terms of __eq__.
# Work around weakref slot with Generics bug on Python 3.6 (https://bugs.python.org/issue41451):
BidictBase.__slots__.remove('__weakref__')
# * Code review nav *
#==============================================================================
# ← Prev: _abc.py Current: _base.py Next: _delegating_mixins.py →
# ← Prev: _abc.py Current: _base.py Next: _frozenbidict.py →
#==============================================================================

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
@ -26,18 +26,23 @@
#==============================================================================
"""Provides :class:`bidict`."""
"""Provide :class:`bidict`."""
import typing as _t
from ._delegating import _DelegatingBidict
from ._mut import MutableBidict
from ._delegating_mixins import _DelegateKeysAndItemsToFwdm
from ._typing import KT, VT
class bidict(_DelegateKeysAndItemsToFwdm, MutableBidict): # noqa: N801,E501; pylint: disable=invalid-name
class bidict(_DelegatingBidict[KT, VT], MutableBidict[KT, VT]):
"""Base class for mutable bidirectional mappings."""
__slots__ = ()
__hash__ = None # since this class is mutable; explicit > implicit.
if _t.TYPE_CHECKING:
@property
def inverse(self) -> 'bidict[VT, KT]': ...
# * Code review nav *

View file

@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
"""Provide :class:`_DelegatingBidict`."""
import typing as _t
from ._base import BidictBase
from ._typing import KT, VT
class _DelegatingBidict(BidictBase[KT, VT]):
"""Provide optimized implementations of several methods by delegating to backing dicts.
Used to override less efficient implementations inherited from :class:`~collections.abc.Mapping`.
"""
__slots__ = ()
def __iter__(self) -> _t.Iterator[KT]:
"""Iterator over the contained keys."""
return iter(self._fwdm)
def keys(self) -> _t.KeysView[KT]:
"""A set-like object providing a view on the contained keys."""
return self._fwdm.keys() # type: ignore [return-value]
def values(self) -> _t.KeysView[VT]: # type: ignore [override] # https://github.com/python/typeshed/issues/4435
"""A set-like object providing a view on the contained values."""
return self._invm.keys() # type: ignore [return-value]
def items(self) -> _t.ItemsView[KT, VT]:
"""A set-like object providing a view on the contained items."""
return self._fwdm.items() # type: ignore [return-value]
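The delegation is observable in the concrete types built on this mixin, e.g. frozenbidict (a sketch that relies on the backing mappings being plain dicts):
from bidict import frozenbidict
f = frozenbidict(one=1)
assert type(f.keys()) is type({}.keys())    # the backing dict's own view, not Mapping's fallback
assert type(f.values()) is type({}.keys())  # values() is the inverse dict's keys view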

View file

@ -1,92 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
#==============================================================================
# * Welcome to the bidict source code *
#==============================================================================
# Doing a code review? You'll find a "Code review nav" comment like the one
# below at the top and bottom of the most important source files. This provides
# a suggested initial path through the source when reviewing.
#
# Note: If you aren't reading this on https://github.com/jab/bidict, you may be
# viewing an outdated version of the code. Please head to GitHub to review the
# latest version, which contains important improvements over older versions.
#
# Thank you for reading and for any feedback you provide.
# * Code review nav *
#==============================================================================
# ← Prev: _base.py Current: _delegating_mixins.py Next: _frozenbidict.py →
#==============================================================================
r"""Provides mixin classes that delegate to ``self._fwdm`` for various operations.
This allows methods such as :meth:`bidict.bidict.items`
to be implemented in terms of a ``self._fwdm.items()`` call,
which is potentially much more efficient (e.g. in CPython 2)
compared to the implementation inherited from :class:`~collections.abc.Mapping`
(which returns ``[(key, self[key]) for key in self]`` in Python 2).
Because this depends on implementation details that aren't necessarily true
(such as the bidict's values being the same as its ``self._fwdm.values()``,
which is not true for e.g. ordered bidicts where ``_fwdm``\'s values are nodes),
these should always be mixed in at a layer below a more general layer,
as they are in e.g. :class:`~bidict.frozenbidict`
which extends :class:`~bidict.BidictBase`.
See the :ref:`extending:Sorted Bidict Recipes`
for another example of where this comes into play.
``SortedBidict`` extends :class:`bidict.MutableBidict`
rather than :class:`bidict.bidict`
to avoid inheriting these mixins,
which are incompatible with the backing
:class:`sortedcontainers.SortedDict`s.
"""
from .compat import PY2
_KEYS_METHODS = ('keys',) + (('viewkeys', 'iterkeys') if PY2 else ())
_ITEMS_METHODS = ('items',) + (('viewitems', 'iteritems') if PY2 else ())
_DOCSTRING_BY_METHOD = {
'keys': 'A set-like object providing a view on the contained keys.',
'items': 'A set-like object providing a view on the contained items.',
}
if PY2:
_DOCSTRING_BY_METHOD['viewkeys'] = _DOCSTRING_BY_METHOD['keys']
_DOCSTRING_BY_METHOD['viewitems'] = _DOCSTRING_BY_METHOD['items']
_DOCSTRING_BY_METHOD['keys'] = 'A list of the contained keys.'
_DOCSTRING_BY_METHOD['items'] = 'A list of the contained items.'
def _make_method(methodname):
def method(self):
return getattr(self._fwdm, methodname)() # pylint: disable=protected-access
method.__name__ = methodname
method.__doc__ = _DOCSTRING_BY_METHOD.get(methodname, '')
return method
def _make_fwdm_delegating_mixin(clsname, methodnames):
clsdict = dict({name: _make_method(name) for name in methodnames}, __slots__=())
return type(clsname, (object,), clsdict)
_DelegateKeysToFwdm = _make_fwdm_delegating_mixin('_DelegateKeysToFwdm', _KEYS_METHODS)
_DelegateItemsToFwdm = _make_fwdm_delegating_mixin('_DelegateItemsToFwdm', _ITEMS_METHODS)
_DelegateKeysAndItemsToFwdm = type(
'_DelegateKeysAndItemsToFwdm',
(_DelegateKeysToFwdm, _DelegateItemsToFwdm),
{'__slots__': ()})
# * Code review nav *
#==============================================================================
# ← Prev: _base.py Current: _delegating_mixins.py Next: _frozenbidict.py →
#==============================================================================

View file

@ -1,36 +1,58 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
"""Provides bidict duplication policies and the :class:`_OnDup` class."""
"""Provide :class:`OnDup` and related functionality."""
from collections import namedtuple
from ._marker import _Marker
from enum import Enum
_OnDup = namedtuple('_OnDup', 'key val kv')
class OnDupAction(Enum):
"""An action to take to prevent duplication from occurring."""
#: Raise a :class:`~bidict.DuplicationError`.
RAISE = 'RAISE'
#: Overwrite existing items with new items.
DROP_OLD = 'DROP_OLD'
#: Keep existing items and drop new items.
DROP_NEW = 'DROP_NEW'
def __repr__(self) -> str:
return f'<{self.name}>'
class DuplicationPolicy(_Marker):
"""Base class for bidict's duplication policies.
RAISE = OnDupAction.RAISE
DROP_OLD = OnDupAction.DROP_OLD
DROP_NEW = OnDupAction.DROP_NEW
class OnDup(namedtuple('_OnDup', 'key val kv')):
r"""A 3-tuple of :class:`OnDupAction`\s specifying how to handle the 3 kinds of duplication.
*See also* :ref:`basic-usage:Values Must Be Unique`
If *kv* is not specified, *val* will be used for *kv*.
"""
__slots__ = ()
def __new__(cls, key: OnDupAction = DROP_OLD, val: OnDupAction = RAISE, kv: OnDupAction = None) -> 'OnDup':  # type: ignore [assignment]  # kv defaults to None so ``kv or val`` below falls back to *val*, per the docstring
"""Override to provide user-friendly default values."""
return super().__new__(cls, key, val, kv or val)
#: Raise an exception when a duplication is encountered.
RAISE = DuplicationPolicy('DUP_POLICY.RAISE')
#: Overwrite an existing item when a duplication is encountered.
OVERWRITE = DuplicationPolicy('DUP_POLICY.OVERWRITE')
#: Keep the existing item and ignore the new item when a duplication is encountered.
IGNORE = DuplicationPolicy('DUP_POLICY.IGNORE')
#: Default :class:`OnDup` used for the
#: :meth:`~bidict.bidict.__init__`,
#: :meth:`~bidict.bidict.__setitem__`, and
#: :meth:`~bidict.bidict.update` methods.
ON_DUP_DEFAULT = OnDup()
#: An :class:`OnDup` whose members are all :obj:`RAISE`.
ON_DUP_RAISE = OnDup(key=RAISE, val=RAISE, kv=RAISE)
#: An :class:`OnDup` whose members are all :obj:`DROP_OLD`.
ON_DUP_DROP_OLD = OnDup(key=DROP_OLD, val=DROP_OLD, kv=DROP_OLD)
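These presets relate as follows (a small sketch; all names are exported by the package):
from bidict import OnDup, OnDupAction, ON_DUP_DEFAULT, DROP_OLD, RAISE
assert ON_DUP_DEFAULT == OnDup(key=DROP_OLD, val=RAISE, kv=RAISE)  # dict-like for keys, strict for values
assert repr(OnDupAction.DROP_OLD) == '<DROP_OLD>'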

View file

@ -1,12 +1,12 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
"""Provides all bidict exceptions."""
"""Provide all bidict exceptions."""
class BidictException(Exception):
@ -15,7 +15,7 @@ class BidictException(Exception):
class DuplicationError(BidictException):
"""Base class for exceptions raised when uniqueness is violated
as per the RAISE duplication policy.
as per the :attr:`~bidict.RAISE` :class:`~bidict.OnDupAction`.
"""

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
@ -22,30 +22,39 @@
# * Code review nav *
#==============================================================================
# ← Prev: _delegating_mixins.py Current: _frozenbidict.py Next: _mut.py →
# ← Prev: _base.py Current: _frozenbidict.py Next: _mut.py →
#==============================================================================
"""Provides :class:`frozenbidict`, an immutable, hashable bidirectional mapping type."""
"""Provide :class:`frozenbidict`, an immutable, hashable bidirectional mapping type."""
from ._base import BidictBase
from ._delegating_mixins import _DelegateKeysAndItemsToFwdm
from .compat import ItemsView
import typing as _t
from ._delegating import _DelegatingBidict
from ._typing import KT, VT
class frozenbidict(_DelegateKeysAndItemsToFwdm, BidictBase): # noqa: N801,E501; pylint: disable=invalid-name
class frozenbidict(_DelegatingBidict[KT, VT]):
"""Immutable, hashable bidict type."""
__slots__ = ()
__slots__ = ('_hash',)
def __hash__(self): # lgtm [py/equals-hash-mismatch]
_hash: int
# Work around lack of support for higher-kinded types in mypy.
# Ref: https://github.com/python/typing/issues/548#issuecomment-621571821
# Remove this and similar type stubs from other classes if support is ever added.
if _t.TYPE_CHECKING:
@property
def inverse(self) -> 'frozenbidict[VT, KT]': ...
def __hash__(self) -> int:
"""The hash of this bidict as determined by its items."""
if getattr(self, '_hash', None) is None:
# pylint: disable=protected-access,attribute-defined-outside-init
self._hash = ItemsView(self)._hash()
self._hash = _t.ItemsView(self)._hash() # type: ignore [attr-defined]
return self._hash
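So equal frozenbidicts hash equal and can serve as dict keys or set members (a minimal sketch):
from bidict import frozenbidict
f = frozenbidict(one=1)
assert hash(f) == hash(frozenbidict({'one': 1}))  # the hash derives from the item set
d = {f: 'frozenbidicts are hashable'}             # usable as a dict key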
# * Code review nav *
#==============================================================================
# ← Prev: _delegating_mixins.py Current: _frozenbidict.py Next: _mut.py →
# ← Prev: _base.py Current: _frozenbidict.py Next: _mut.py →
#==============================================================================

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
@ -25,38 +25,61 @@
# ← Prev: _orderedbase.py Current: _frozenordered.py Next: _orderedbidict.py →
#==============================================================================
"""Provides :class:`FrozenOrderedBidict`, an immutable, hashable, ordered bidict."""
"""Provide :class:`FrozenOrderedBidict`, an immutable, hashable, ordered bidict."""
import typing as _t
from ._delegating_mixins import _DelegateKeysToFwdm
from ._frozenbidict import frozenbidict
from ._orderedbase import OrderedBidictBase
from .compat import DICTS_ORDERED, PY2, izip
from ._typing import KT, VT
# If the Python implementation's dict type is ordered (e.g. PyPy or CPython >= 3.6), then
# `FrozenOrderedBidict` can delegate to `_fwdm` for keys: Both `_fwdm` and `_invm` will always
# be initialized with the provided items in the correct order, and since `FrozenOrderedBidict`
# is immutable, their respective orders can't get out of sync after a mutation. (Can't delegate
# to `_fwdm` for items though because values in `_fwdm` are nodes.)
_BASES = ((_DelegateKeysToFwdm,) if DICTS_ORDERED else ()) + (OrderedBidictBase,)
_CLSDICT = dict(
__slots__=(),
# Must set __hash__ explicitly, Python prevents inheriting it.
# frozenbidict.__hash__ can be reused for FrozenOrderedBidict:
# FrozenOrderedBidict inherits BidictBase.__eq__ which is order-insensitive,
# and frozenbidict.__hash__ is consistent with BidictBase.__eq__.
__hash__=frozenbidict.__hash__.__func__ if PY2 else frozenbidict.__hash__,
__doc__='Hashable, immutable, ordered bidict type.',
__module__=__name__, # Otherwise unpickling fails in Python 2.
)
class FrozenOrderedBidict(OrderedBidictBase[KT, VT]):
"""Hashable, immutable, ordered bidict type.
# When PY2 (so we provide iteritems) and DICTS_ORDERED, e.g. on PyPy, the following implementation
# of iteritems may be more efficient than that inherited from `Mapping`. This exploits the property
# that the keys in `_fwdm` and `_invm` are already in the right order:
if PY2 and DICTS_ORDERED:
_CLSDICT['iteritems'] = lambda self: izip(self._fwdm, self._invm) # noqa: E501; pylint: disable=protected-access
Like a hashable :class:`bidict.OrderedBidict`
without the mutating APIs, or like a
reversible :class:`bidict.frozenbidict` even on Python < 3.8.
(All bidicts are order-preserving when never mutated, so frozenbidict is
already order-preserving, but only on Python 3.8+, where dicts are
reversible, are all bidicts (including frozenbidict) also reversible.)
FrozenOrderedBidict = type('FrozenOrderedBidict', _BASES, _CLSDICT) # pylint: disable=invalid-name
If you are using Python 3.8+, frozenbidict gives you everything that
FrozenOrderedBidict gives you, but with less space overhead.
"""
__slots__ = ('_hash',)
__hash__ = frozenbidict.__hash__
if _t.TYPE_CHECKING:
@property
def inverse(self) -> 'FrozenOrderedBidict[VT, KT]': ...
# Delegate to backing dicts for more efficient implementations of keys() and values().
# Possible with FrozenOrderedBidict but not OrderedBidict since FrozenOrderedBidict
# is immutable, i.e. these can't get out of sync after initialization due to mutation.
def keys(self) -> _t.KeysView[KT]:
"""A set-like object providing a view on the contained keys."""
return self._fwdm._fwdm.keys() # type: ignore [return-value]
def values(self) -> _t.KeysView[VT]: # type: ignore [override]
"""A set-like object providing a view on the contained values."""
return self._invm._fwdm.keys() # type: ignore [return-value]
# Can't delegate for items() because values in _fwdm and _invm are nodes.
# On Python 3.8+, delegate to backing dicts for a more efficient implementation
# of __iter__ and __reversed__ (both of which call this _iter() method):
if hasattr(dict, '__reversed__'):
def _iter(self, *, reverse: bool = False) -> _t.Iterator[KT]:
itfn = reversed if reverse else iter
return itfn(self._fwdm._fwdm) # type: ignore [operator,no-any-return]
else:
# On Python < 3.8, just optimize __iter__:
def _iter(self, *, reverse: bool = False) -> _t.Iterator[KT]:
if not reverse:
return iter(self._fwdm._fwdm)
return super()._iter(reverse=True)
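A sketch of the ordering guarantees described above:
from bidict import FrozenOrderedBidict
f = FrozenOrderedBidict([('one', 1), ('two', 2)])
assert list(f) == ['one', 'two']             # insertion order is preserved
assert list(reversed(f)) == ['two', 'one']   # and reversible on all supported Pythons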
# * Code review nav *

View file

@ -1,50 +1,56 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
"""Useful functions for working with bidirectional mappings and related data."""
"""Functions for iterating over items in a mapping."""
from itertools import chain, repeat
import typing as _t
from collections.abc import Mapping
from itertools import chain
from .compat import iteritems, Mapping
from ._typing import KT, VT, IterItems, MapOrIterItems
_NULL_IT = repeat(None, 0) # repeat 0 times -> raise StopIteration from the start
_NULL_IT: IterItems = iter(())
def _iteritems_mapping_or_iterable(arg):
def _iteritems_mapping_or_iterable(arg: MapOrIterItems[KT, VT]) -> IterItems[KT, VT]:
"""Yield the items in *arg*.
If *arg* is a :class:`~collections.abc.Mapping`, return an iterator over its items.
Otherwise return an iterator over *arg* itself.
"""
return iteritems(arg) if isinstance(arg, Mapping) else iter(arg)
return iter(arg.items() if isinstance(arg, Mapping) else arg)
def _iteritems_args_kw(*args, **kw):
def _iteritems_args_kw(*args: MapOrIterItems[KT, VT], **kw: VT) -> IterItems[KT, VT]:
"""Yield the items from the positional argument (if given) and then any from *kw*.
:raises TypeError: if more than one positional argument is given.
"""
args_len = len(args)
if args_len > 1:
raise TypeError('Expected at most 1 positional argument, got %d' % args_len)
itemchain = None
raise TypeError(f'Expected at most 1 positional argument, got {args_len}')
it: IterItems = ()
if args:
arg = args[0]
if arg:
itemchain = _iteritems_mapping_or_iterable(arg)
it = _iteritems_mapping_or_iterable(arg)
if kw:
iterkw = iteritems(kw)
itemchain = chain(itemchain, iterkw) if itemchain else iterkw
return itemchain or _NULL_IT
iterkw = iter(kw.items())
it = chain(it, iterkw) if it else iterkw
return it or _NULL_IT
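Behavior sketch (this is an internal helper, imported here only for illustration):
from bidict._iter import _iteritems_args_kw
assert list(_iteritems_args_kw({'a': 1}, b=2)) == [('a', 1), ('b', 2)]  # positional arg first, then kw
assert list(_iteritems_args_kw()) == []                                 # no input -> empty iterator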
def inverted(arg):
@_t.overload
def inverted(arg: _t.Mapping[KT, VT]) -> IterItems[VT, KT]: ...
@_t.overload
def inverted(arg: IterItems[KT, VT]) -> IterItems[VT, KT]: ...
def inverted(arg: MapOrIterItems[KT, VT]) -> IterItems[VT, KT]:
"""Yield the inverse items of the provided object.
If *arg* has a :func:`callable` ``__inverted__`` attribute,
@ -57,5 +63,5 @@ def inverted(arg):
"""
inv = getattr(arg, '__inverted__', None)
if callable(inv):
return inv()
return inv() # type: ignore [no-any-return]
return ((val, key) for (key, val) in _iteritems_mapping_or_iterable(arg))
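Usage sketch (inverted is part of the public API):
from bidict import bidict, inverted
assert dict(inverted({1: 'one'})) == {'one': 1}   # generic mappings get their pairs swapped
b = bidict(one=1)
assert dict(inverted(b)) == dict(b.inverse)       # bidicts supply __inverted__, which is used instead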

View file

@ -1,19 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
"""Provides :class:`_Marker`, an internal type for representing singletons."""
from collections import namedtuple
class _Marker(namedtuple('_Marker', 'name')):
__slots__ = ()
def __repr__(self):
return '<%s>' % self.name # pragma: no cover

View file

@ -1,14 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
"""Provides the :obj:`_MISS` sentinel, for internally signaling "missing/not found"."""
from ._marker import _Marker
_MISS = _Marker('MISSING')

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
@ -26,32 +26,31 @@
#==============================================================================
"""Provides :class:`bidict`."""
"""Provide :class:`MutableBidict`."""
import typing as _t
from ._abc import MutableBidirectionalMapping
from ._base import BidictBase
from ._dup import OVERWRITE, RAISE, _OnDup
from ._miss import _MISS
from .compat import MutableMapping
from ._dup import OnDup, ON_DUP_RAISE, ON_DUP_DROP_OLD
from ._typing import _NONE, KT, VT, VDT, IterItems, MapOrIterItems
# Extend MutableMapping explicitly because it doesn't implement __subclasshook__, as well as to
# inherit method implementations it provides that we can reuse (namely `setdefault`).
class MutableBidict(BidictBase, MutableMapping):
class MutableBidict(BidictBase[KT, VT], MutableBidirectionalMapping[KT, VT]):
"""Base class for mutable bidirectional mappings."""
__slots__ = ()
__hash__ = None # since this class is mutable; explicit > implicit.
if _t.TYPE_CHECKING:
@property
def inverse(self) -> 'MutableBidict[VT, KT]': ...
_ON_DUP_OVERWRITE = _OnDup(key=OVERWRITE, val=OVERWRITE, kv=OVERWRITE)
def __delitem__(self, key):
u"""*x.__delitem__(y)  del x[y]*"""
def __delitem__(self, key: KT) -> None:
"""*x.__delitem__(y)  del x[y]*"""
self._pop(key)
def __setitem__(self, key, val):
"""
Set the value for *key* to *val*.
def __setitem__(self, key: KT, val: VT) -> None:
"""Set the value for *key* to *val*.
If *key* is already associated with *val*, this is a no-op.
@ -64,7 +63,7 @@ class MutableBidict(BidictBase, MutableMapping):
to protect against accidental removal of the key
that's currently associated with *val*.
Use :meth:`put` instead if you want to specify different policy in
Use :meth:`put` instead if you want to specify different behavior in
the case that the provided key or value duplicates an existing one.
Or use :meth:`forceput` to unconditionally associate *key* with *val*,
replacing any existing items as necessary to preserve uniqueness.
@ -76,16 +75,12 @@ class MutableBidict(BidictBase, MutableMapping):
existing item and *val* duplicates the value of a different
existing item.
"""
on_dup = self._get_on_dup()
self._put(key, val, on_dup)
self._put(key, val, self.on_dup)
def put(self, key, val, on_dup_key=RAISE, on_dup_val=RAISE, on_dup_kv=None):
"""
Associate *key* with *val* with the specified duplication policies.
def put(self, key: KT, val: VT, on_dup: OnDup = ON_DUP_RAISE) -> None:
"""Associate *key* with *val*, honoring the :class:`OnDup` given in *on_dup*.
If *on_dup_kv* is ``None``, the *on_dup_val* policy will be used for it.
For example, if all given duplication policies are :attr:`~bidict.RAISE`,
For example, if *on_dup* is :attr:`~bidict.ON_DUP_RAISE`,
then *key* will be associated with *val* if and only if
*key* is not already associated with an existing value and
*val* is not already associated with an existing key,
@ -94,37 +89,39 @@ class MutableBidict(BidictBase, MutableMapping):
If *key* is already associated with *val*, this is a no-op.
:raises bidict.KeyDuplicationError: if attempting to insert an item
whose key only duplicates an existing item's, and *on_dup_key* is
whose key only duplicates an existing item's, and *on_dup.key* is
:attr:`~bidict.RAISE`.
:raises bidict.ValueDuplicationError: if attempting to insert an item
whose value only duplicates an existing item's, and *on_dup_val* is
whose value only duplicates an existing item's, and *on_dup.val* is
:attr:`~bidict.RAISE`.
:raises bidict.KeyAndValueDuplicationError: if attempting to insert an
item whose key duplicates one existing item's, and whose value
duplicates another existing item's, and *on_dup_kv* is
duplicates another existing item's, and *on_dup.kv* is
:attr:`~bidict.RAISE`.
"""
on_dup = self._get_on_dup((on_dup_key, on_dup_val, on_dup_kv))
self._put(key, val, on_dup)
def forceput(self, key, val):
"""
Associate *key* with *val* unconditionally.
def forceput(self, key: KT, val: VT) -> None:
"""Associate *key* with *val* unconditionally.
Replace any existing mappings containing key *key* or value *val*
as necessary to preserve uniqueness.
"""
self._put(key, val, self._ON_DUP_OVERWRITE)
self._put(key, val, ON_DUP_DROP_OLD)
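For example (a minimal sketch), forceput evicts every colliding item to keep both directions unique:
from bidict import bidict
b = bidict(one=1, two=2)
b.forceput('one', 2)      # evicts both ('one', 1) and ('two', 2) to make room
assert b == {'one': 2}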
def clear(self):
def clear(self) -> None:
"""Remove all items."""
self._fwdm.clear()
self._invm.clear()
def pop(self, key, default=_MISS):
u"""*x.pop(k[, d]) → v*
@_t.overload
def pop(self, key: KT) -> VT: ...
@_t.overload
def pop(self, key: KT, default: VDT = ...) -> VDT: ...
def pop(self, key: KT, default: VDT = _NONE) -> VDT:
"""*x.pop(k[, d]) → v*
Remove specified key and return the corresponding value.
@ -133,12 +130,12 @@ class MutableBidict(BidictBase, MutableMapping):
try:
return self._pop(key)
except KeyError:
if default is _MISS:
if default is _NONE:
raise
return default
def popitem(self):
u"""*x.popitem() → (k, v)*
def popitem(self) -> _t.Tuple[KT, VT]:
"""*x.popitem() → (k, v)*
Remove and return some item as a (key, value) pair.
@ -150,24 +147,38 @@ class MutableBidict(BidictBase, MutableMapping):
del self._invm[val]
return key, val
def update(self, *args, **kw):
"""Like :meth:`putall` with default duplication policies."""
@_t.overload
def update(self, __arg: _t.Mapping[KT, VT], **kw: VT) -> None: ...
@_t.overload
def update(self, __arg: IterItems[KT, VT], **kw: VT) -> None: ...
@_t.overload
def update(self, **kw: VT) -> None: ...
def update(self, *args: MapOrIterItems[KT, VT], **kw: VT) -> None:
"""Like calling :meth:`putall` with *self.on_dup* passed for *on_dup*."""
if args or kw:
self._update(False, None, *args, **kw)
self._update(False, self.on_dup, *args, **kw)
def forceupdate(self, *args, **kw):
@_t.overload
def forceupdate(self, __arg: _t.Mapping[KT, VT], **kw: VT) -> None: ...
@_t.overload
def forceupdate(self, __arg: IterItems[KT, VT], **kw: VT) -> None: ...
@_t.overload
def forceupdate(self, **kw: VT) -> None: ...
def forceupdate(self, *args: MapOrIterItems[KT, VT], **kw: VT) -> None:
"""Like a bulk :meth:`forceput`."""
self._update(False, self._ON_DUP_OVERWRITE, *args, **kw)
self._update(False, ON_DUP_DROP_OLD, *args, **kw)
def putall(self, items, on_dup_key=RAISE, on_dup_val=RAISE, on_dup_kv=None):
"""
Like a bulk :meth:`put`.
@_t.overload
def putall(self, items: _t.Mapping[KT, VT], on_dup: OnDup) -> None: ...
@_t.overload
def putall(self, items: IterItems[KT, VT], on_dup: OnDup = ON_DUP_RAISE) -> None: ...
def putall(self, items: MapOrIterItems[KT, VT], on_dup: OnDup = ON_DUP_RAISE) -> None:
"""Like a bulk :meth:`put`.
If one of the given items causes an exception to be raised,
none of the items is inserted.
"""
if items:
on_dup = self._get_on_dup((on_dup_key, on_dup_val, on_dup_kv))
self._update(False, on_dup, items)

View file

@ -1,34 +1,35 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
"""Provides :func:`bidict.namedbidict`."""
"""Provide :func:`bidict.namedbidict`."""
import re
import typing as _t
from sys import _getframe
from ._abc import BidirectionalMapping
from ._abc import BidirectionalMapping, KT, VT
from ._bidict import bidict
from .compat import PY2
_isidentifier = ( # pylint: disable=invalid-name
re.compile('[A-Za-z_][A-Za-z0-9_]*$').match if PY2 else str.isidentifier
)
def namedbidict(typename, keyname, valname, base_type=bidict):
def namedbidict(
typename: str,
keyname: str,
valname: str,
*,
base_type: _t.Type[BidirectionalMapping[KT, VT]] = bidict,
) -> _t.Type[BidirectionalMapping[KT, VT]]:
r"""Create a new subclass of *base_type* with custom accessors.
Analogous to :func:`collections.namedtuple`.
Like :func:`collections.namedtuple` for bidicts.
The new class's ``__name__`` and ``__qualname__``
will be set based on *typename*.
The new class's ``__name__`` and ``__qualname__`` will be set to *typename*,
and its ``__module__`` will be set to the caller's module.
Instances of it will provide access to their
:attr:`inverse <BidirectionalMapping.inverse>`\s
Instances of the new class will provide access to their
:attr:`inverse <BidirectionalMapping.inverse>` instances
via the custom *keyname*\_for property,
and access to themselves
via the custom *valname*\_for property.
@ -39,63 +40,58 @@ def namedbidict(typename, keyname, valname, base_type=bidict):
:raises ValueError: if any of the *typename*, *keyname*, or *valname*
strings is not a valid Python identifier, or if *keyname == valname*.
:raises TypeError: if *base_type* is not a subclass of
:class:`BidirectionalMapping`.
(This function requires slightly more of *base_type*,
e.g. the availability of an ``_isinv`` attribute,
but all the :ref:`concrete bidict types
<other-bidict-types:Bidict Types Diagram>`
that the :mod:`bidict` module provides can be passed in.
Check out the code if you actually need to pass in something else.)
:raises TypeError: if *base_type* is not a :class:`BidirectionalMapping` subclass
that provides ``_isinv`` and :meth:`~object.__getstate__` attributes.
(Any :class:`~bidict.BidictBase` subclass can be passed in, including all the
concrete bidict types pictured in the :ref:`other-bidict-types:Bidict Types Diagram`.)
"""
# Re the `base_type` docs above:
# The additional requirements (providing _isinv and __getstate__) do not belong in the
# BidirectionalMapping interface, and it's overkill to create additional interface(s) for this.
# On the other hand, it's overkill to require that base_type be a subclass of BidictBase, since
# that's too specific. The BidirectionalMapping check along with the docs above should suffice.
if not issubclass(base_type, BidirectionalMapping):
if not issubclass(base_type, BidirectionalMapping) or not all(hasattr(base_type, i) for i in ('_isinv', '__getstate__')):
raise TypeError(base_type)
names = (typename, keyname, valname)
if not all(map(_isidentifier, names)) or keyname == valname:
if not all(map(str.isidentifier, names)) or keyname == valname:
raise ValueError(names)
class _Named(base_type): # pylint: disable=too-many-ancestors
class _Named(base_type): # type: ignore [valid-type,misc]
__slots__ = ()
def _getfwd(self):
return self.inverse if self._isinv else self
def _getfwd(self) -> '_Named':
return self.inverse if self._isinv else self # type: ignore [no-any-return]
def _getinv(self):
return self if self._isinv else self.inverse
def _getinv(self) -> '_Named':
return self if self._isinv else self.inverse # type: ignore [no-any-return]
@property
def _keyname(self):
def _keyname(self) -> str:
return valname if self._isinv else keyname
@property
def _valname(self):
def _valname(self) -> str:
return keyname if self._isinv else valname
def __reduce__(self):
def __reduce__(self) -> '_t.Tuple[_t.Callable[[str, str, str, _t.Type[BidirectionalMapping]], BidirectionalMapping], _t.Tuple[str, str, str, _t.Type[BidirectionalMapping]], dict]':
return (_make_empty, (typename, keyname, valname, base_type), self.__getstate__())
bname = base_type.__name__
fname = valname + '_for'
iname = keyname + '_for'
names = dict(typename=typename, bname=bname, keyname=keyname, valname=valname)
fdoc = u'{typename} forward {bname}: {keyname}{valname}'.format(**names)
idoc = u'{typename} inverse {bname}: {valname}{keyname}'.format(**names)
setattr(_Named, fname, property(_Named._getfwd, doc=fdoc)) # pylint: disable=protected-access
setattr(_Named, iname, property(_Named._getinv, doc=idoc)) # pylint: disable=protected-access
fdoc = f'{typename} forward {bname}: {keyname}{valname}'
idoc = f'{typename} inverse {bname}: {valname}{keyname}'
setattr(_Named, fname, property(_Named._getfwd, doc=fdoc))
setattr(_Named, iname, property(_Named._getinv, doc=idoc))
if not PY2:
_Named.__qualname__ = _Named.__qualname__[:-len(_Named.__name__)] + typename
_Named.__name__ = typename
_Named.__qualname__ = typename
_Named.__module__ = _getframe(1).f_globals.get('__name__') # type: ignore [assignment]
return _Named
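Usage sketch (ElementMap and its field names are illustrative):
from bidict import namedbidict
ElementMap = namedbidict('ElementMap', 'symbol', 'name')
noble = ElementMap(He='helium')
assert noble.name_for['He'] == 'helium'    # forward accessor: symbol -> name
assert noble.symbol_for['helium'] == 'He'  # inverse accessor: name -> symbol
assert noble.inverse['helium'] == 'He'     # the plain inverse still works too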
def _make_empty(typename, keyname, valname, base_type):
def _make_empty(
typename: str,
keyname: str,
valname: str,
base_type: _t.Type[BidirectionalMapping] = bidict,
) -> BidirectionalMapping:
"""Create a named bidict with the indicated arguments and return an empty instance.
Used to make :func:`bidict.namedbidict` instances picklable.
"""

View file

@ -1,14 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
"""Provides the :obj:`_NOOP` sentinel, for internally signaling "no-op"."""
from ._marker import _Marker
_NOOP = _Marker('NO-OP')

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
@ -26,17 +26,19 @@
#==============================================================================
"""Provides :class:`OrderedBidictBase`."""
"""Provide :class:`OrderedBidictBase`."""
import typing as _t
from copy import copy
from weakref import ref
from ._base import _WriteResult, BidictBase
from ._abc import MutableBidirectionalMapping
from ._base import _NONE, _DedupResult, _WriteResult, BidictBase, BT
from ._bidict import bidict
from ._miss import _MISS
from .compat import Mapping, PY2, iteritems, izip
from ._typing import KT, VT, OKT, OVT, IterItems, MapOrIterItems
class _Node(object): # pylint: disable=too-few-public-methods
class _Node:
"""A node in a circular doubly-linked list
used to encode the order of items in an ordered bidict.
@ -55,33 +57,33 @@ class _Node(object): # pylint: disable=too-few-public-methods
__slots__ = ('_prv', '_nxt', '__weakref__')
def __init__(self, prv=None, nxt=None):
def __init__(self, prv: '_Node' = None, nxt: '_Node' = None) -> None:
self._setprv(prv)
self._setnxt(nxt)
def __repr__(self): # pragma: no cover
def __repr__(self) -> str:
clsname = self.__class__.__name__
prv = id(self.prv)
nxt = id(self.nxt)
return '%s(prv=%s, self=%s, nxt=%s)' % (clsname, prv, id(self), nxt)
return f'{clsname}(prv={prv}, self={id(self)}, nxt={nxt})'
def _getprv(self):
def _getprv(self) -> '_t.Optional[_Node]':
return self._prv() if isinstance(self._prv, ref) else self._prv
def _setprv(self, prv):
def _setprv(self, prv: '_t.Optional[_Node]') -> None:
self._prv = prv and ref(prv)
prv = property(_getprv, _setprv)
def _getnxt(self):
def _getnxt(self) -> '_t.Optional[_Node]':
return self._nxt() if isinstance(self._nxt, ref) else self._nxt
def _setnxt(self, nxt):
def _setnxt(self, nxt: '_t.Optional[_Node]') -> None:
self._nxt = nxt and ref(nxt)
nxt = property(_getnxt, _setnxt)
def __getstate__(self):
def __getstate__(self) -> dict:
"""Return the instance state dictionary
but with weakrefs converted to strong refs
so that it can be pickled.
@ -90,13 +92,13 @@ class _Node(object): # pylint: disable=too-few-public-methods
"""
return dict(_prv=self.prv, _nxt=self.nxt)
def __setstate__(self, state):
def __setstate__(self, state: dict) -> None:
"""Set the instance state from *state*."""
self._setprv(state['_prv'])
self._setnxt(state['_nxt'])
class _Sentinel(_Node): # pylint: disable=too-few-public-methods
class _SentinelNode(_Node):
"""Special node in a circular doubly-linked list
that links the first node with the last node.
When its next and previous references point back to itself
@ -105,19 +107,16 @@ class _Sentinel(_Node): # pylint: disable=too-few-public-methods
__slots__ = ()
def __init__(self, prv=None, nxt=None):
super(_Sentinel, self).__init__(prv or self, nxt or self)
def __init__(self, prv: _Node = None, nxt: _Node = None) -> None:
super().__init__(prv or self, nxt or self)
def __repr__(self): # pragma: no cover
return '<SENTINEL>'
def __repr__(self) -> str:
return '<SNTL>'
def __bool__(self):
def __bool__(self) -> bool:
return False
if PY2:
__nonzero__ = __bool__
def __iter__(self, reverse=False):
def _iter(self, *, reverse: bool = False) -> _t.Iterator[_Node]:
"""Iterator yielding nodes in the requested order,
i.e. traverse the linked list via :attr:`nxt`
(or :attr:`prv` if *reverse* is truthy)
@ -130,26 +129,35 @@ class _Sentinel(_Node): # pylint: disable=too-few-public-methods
node = getattr(node, attr)
class OrderedBidictBase(BidictBase):
class OrderedBidictBase(BidictBase[KT, VT]):
"""Base class implementing an ordered :class:`BidirectionalMapping`."""
__slots__ = ('_sntl',)
_fwdm_cls = bidict
_invm_cls = bidict
_fwdm_cls: _t.Type[MutableBidirectionalMapping[KT, _Node]] = bidict # type: ignore [assignment]
_invm_cls: _t.Type[MutableBidirectionalMapping[VT, _Node]] = bidict # type: ignore [assignment]
_fwdm: bidict[KT, _Node] # type: ignore [assignment]
_invm: bidict[VT, _Node] # type: ignore [assignment]
#: The object used by :meth:`__repr__` for printing the contained items.
_repr_delegate = list
def __init__(self, *args, **kw):
@_t.overload
def __init__(self, __arg: _t.Mapping[KT, VT], **kw: VT) -> None: ...
@_t.overload
def __init__(self, __arg: IterItems[KT, VT], **kw: VT) -> None: ...
@_t.overload
def __init__(self, **kw: VT) -> None: ...
def __init__(self, *args: MapOrIterItems[KT, VT], **kw: VT) -> None:
"""Make a new ordered bidirectional mapping.
The signature is the same as that of regular dictionaries.
The signature behaves like that of :class:`dict`.
Items passed in are added in the order they are passed,
respecting this bidict type's duplication policies along the way.
respecting the :attr:`on_dup` class attribute in the process.
The order in which items are inserted is remembered,
similar to :class:`collections.OrderedDict`.
"""
self._sntl = _Sentinel()
self._sntl = _SentinelNode()
# Like unordered bidicts, ordered bidicts also store two backing one-directional mappings
# `_fwdm` and `_invm`. But rather than mapping `key` to `val` and `val` to `key`
@@ -159,55 +167,58 @@ class OrderedBidictBase(BidictBase):
# To effect this difference, `_write_item` and `_undo_write` are overridden. But much of the
# rest of BidictBase's implementation, including BidictBase.__init__ and BidictBase._update,
# are inherited and are able to be reused without modification.
super(OrderedBidictBase, self).__init__(*args, **kw)
super().__init__(*args, **kw)
def _init_inv(self):
super(OrderedBidictBase, self)._init_inv()
self.inverse._sntl = self._sntl # pylint: disable=protected-access
if _t.TYPE_CHECKING:
@property
def inverse(self) -> 'OrderedBidictBase[VT, KT]': ...
def _init_inv(self) -> None:
super()._init_inv()
self.inverse._sntl = self._sntl
# Can't reuse BidictBase.copy since ordered bidicts have different internal structure.
def copy(self):
def copy(self: BT) -> BT:
"""A shallow copy of this ordered bidict."""
# Fast copy implementation bypassing __init__. See comments in :meth:`BidictBase.copy`.
copy = self.__class__.__new__(self.__class__)
sntl = _Sentinel()
fwdm = self._fwdm.copy()
invm = self._invm.copy()
cp: BT = self.__class__.__new__(self.__class__)
sntl = _SentinelNode()
fwdm = copy(self._fwdm)
invm = copy(self._invm)
cur = sntl
nxt = sntl.nxt
for (key, val) in iteritems(self):
for (key, val) in self.items():
nxt = _Node(cur, sntl)
cur.nxt = fwdm[key] = invm[val] = nxt
cur = nxt
sntl.prv = nxt
copy._sntl = sntl # pylint: disable=protected-access
copy._fwdm = fwdm # pylint: disable=protected-access
copy._invm = invm # pylint: disable=protected-access
copy._init_inv() # pylint: disable=protected-access
return copy
cp._sntl = sntl # type: ignore [attr-defined]
cp._fwdm = fwdm
cp._invm = invm
cp._init_inv()
return cp
def __getitem__(self, key):
__copy__ = copy
def __getitem__(self, key: KT) -> VT:
nodefwd = self._fwdm[key]
val = self._invm.inverse[nodefwd]
return val
def _pop(self, key):
def _pop(self, key: KT) -> VT:
nodefwd = self._fwdm.pop(key)
val = self._invm.inverse.pop(nodefwd)
nodefwd.prv.nxt = nodefwd.nxt
nodefwd.nxt.prv = nodefwd.prv
return val
def _isdupitem(self, key, val, dedup_result):
"""Return whether (key, val) duplicates an existing item."""
isdupkey, isdupval, nodeinv, nodefwd = dedup_result
isdupitem = nodeinv is nodefwd
if isdupitem:
assert isdupkey
assert isdupval
return isdupitem
@staticmethod
def _already_have(key: KT, val: VT, nodeinv: _Node, nodefwd: _Node) -> bool: # type: ignore [override]
# Overrides _base.BidictBase.
return nodeinv is nodefwd
def _write_item(self, key, val, dedup_result): # pylint: disable=too-many-locals
def _write_item(self, key: KT, val: VT, dedup_result: _DedupResult) -> _WriteResult:
# Overrides _base.BidictBase.
fwdm = self._fwdm # bidict mapping keys to nodes
invm = self._invm # bidict mapping vals to nodes
isdupkey, isdupval, nodeinv, nodefwd = dedup_result
@@ -217,7 +228,8 @@ class OrderedBidictBase(BidictBase):
last = sntl.prv
node = _Node(last, sntl)
last.nxt = sntl.prv = fwdm[key] = invm[val] = node
oldkey = oldval = _MISS
oldkey: OKT = _NONE
oldval: OVT = _NONE
elif isdupkey and isdupval:
# Key and value duplication across two different nodes.
assert nodefwd is not nodeinv
@@ -239,19 +251,19 @@ class OrderedBidictBase(BidictBase):
fwdm[key] = invm[val] = nodefwd
elif isdupkey:
oldval = invm.inverse[nodefwd]
oldkey = _MISS
oldkey = _NONE
oldnodeinv = invm.pop(oldval)
assert oldnodeinv is nodefwd
invm[val] = nodefwd
else: # isdupval
oldkey = fwdm.inverse[nodeinv]
oldval = _MISS
oldval = _NONE
oldnodefwd = fwdm.pop(oldkey)
assert oldnodefwd is nodeinv
fwdm[key] = nodeinv
return _WriteResult(key, val, oldkey, oldval)
def _undo_write(self, dedup_result, write_result): # pylint: disable=too-many-locals
def _undo_write(self, dedup_result: _DedupResult, write_result: _WriteResult) -> None:
fwdm = self._fwdm
invm = self._invm
isdupkey, isdupval, nodeinv, nodefwd = dedup_result
@@ -274,26 +286,18 @@ class OrderedBidictBase(BidictBase):
fwdm[oldkey] = nodeinv
assert invm[val] is nodeinv
def __iter__(self, reverse=False):
"""An iterator over this bidict's items in order."""
def __iter__(self) -> _t.Iterator[KT]:
"""Iterator over the contained keys in insertion order."""
return self._iter()
def _iter(self, *, reverse: bool = False) -> _t.Iterator[KT]:
fwdm_inv = self._fwdm.inverse
for node in self._sntl.__iter__(reverse=reverse):
for node in self._sntl._iter(reverse=reverse):
yield fwdm_inv[node]
def __reversed__(self):
"""An iterator over this bidict's items in reverse order."""
for key in self.__iter__(reverse=True):
yield key
def equals_order_sensitive(self, other):
"""Order-sensitive equality check.
*See also* :ref:`eq-order-insensitive`
"""
# Same short-circuit as BidictBase.__eq__. Factoring out not worth function call overhead.
if not isinstance(other, Mapping) or len(self) != len(other):
return False
return all(i == j for (i, j) in izip(iteritems(self), iteritems(other)))
def __reversed__(self) -> _t.Iterator[KT]:
"""Iterator over the contained keys in reverse insertion order."""
yield from self._iter(reverse=True)
# * Code review nav *
libs/bidict/_orderedbidict.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -26,26 +26,32 @@
#==============================================================================
"""Provides :class:`OrderedBidict`."""
"""Provide :class:`OrderedBidict`."""
import typing as _t
from ._mut import MutableBidict
from ._orderedbase import OrderedBidictBase
from ._typing import KT, VT
class OrderedBidict(OrderedBidictBase, MutableBidict):
class OrderedBidict(OrderedBidictBase[KT, VT], MutableBidict[KT, VT]):
"""Mutable bidict type that maintains items in insertion order."""
__slots__ = ()
__hash__ = None # since this class is mutable; explicit > implicit.
def clear(self):
if _t.TYPE_CHECKING:
@property
def inverse(self) -> 'OrderedBidict[VT, KT]': ...
def clear(self) -> None:
"""Remove all items."""
self._fwdm.clear()
self._invm.clear()
self._sntl.nxt = self._sntl.prv = self._sntl
def popitem(self, last=True): # pylint: disable=arguments-differ
u"""*x.popitem() → (k, v)*
def popitem(self, last: bool = True) -> _t.Tuple[KT, VT]:
"""*x.popitem() → (k, v)*
Remove and return the most recently added item as a (key, value) pair
if *last* is True, else the least recently added item.
@@ -54,11 +60,13 @@ class OrderedBidict(OrderedBidictBase, MutableBidict):
"""
if not self:
raise KeyError('mapping is empty')
key = next((reversed if last else iter)(self))
itfn: _t.Callable = reversed if last else iter # type: ignore [assignment]
it = itfn(self)
key = next(it)
val = self._pop(key)
return key, val
def move_to_end(self, key, last=True):
def move_to_end(self, key: KT, last: bool = True) -> None:
"""Move an existing key to the beginning or end of this ordered bidict.
The item is moved to the end if *last* is True, else to the beginning.
@@ -70,15 +78,15 @@ class OrderedBidict(OrderedBidictBase, MutableBidict):
node.nxt.prv = node.prv
sntl = self._sntl
if last:
last = sntl.prv
node.prv = last
lastnode = sntl.prv
node.prv = lastnode
node.nxt = sntl
sntl.prv = last.nxt = node
sntl.prv = lastnode.nxt = node
else:
first = sntl.nxt
firstnode = sntl.nxt
node.prv = sntl
node.nxt = first
sntl.nxt = first.prv = node
node.nxt = firstnode
sntl.nxt = firstnode.prv = node
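For orientation, a hedged usage sketch of the ordered API this diff touches, assuming the vendored bidict 0.21.4 behaves like upstream:

from bidict import OrderedBidict

ob = OrderedBidict([('one', 1), ('two', 2), ('three', 3)])
assert list(ob) == ['one', 'two', 'three']   # insertion order preserved
assert ob.inverse[2] == 'two'                # inverse mapping stays in sync
ob.move_to_end('one')                        # order is now: two, three, one
assert ob.popitem(last=True) == ('one', 1)   # pop from the end
assert ob.popitem(last=False) == ('two', 2)  # pop from the beginning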
# * Code review nav *
libs/bidict/_typing.py (new file)
@@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
"""Provide typing-related objects."""
import typing as _t
KT = _t.TypeVar('KT')
VT = _t.TypeVar('VT')
IterItems = _t.Iterable[_t.Tuple[KT, VT]]
MapOrIterItems = _t.Union[_t.Mapping[KT, VT], IterItems[KT, VT]]
DT = _t.TypeVar('DT') #: for default arguments
VDT = _t.Union[VT, DT]
class _BareReprMeta(type):
def __repr__(cls) -> str:
return f'<{cls.__name__}>'
class _NONE(metaclass=_BareReprMeta):
"""Sentinel type used to represent 'missing'."""
OKT = _t.Union[KT, _NONE] #: optional key type
OVT = _t.Union[VT, _NONE] #: optional value type
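The _NONE sentinel above is a class used directly as a value, with a metaclass supplying a clean repr. The same pattern, reduced to a standalone sketch with invented names:

class _BareReprMeta(type):
    def __repr__(cls):
        return f'<{cls.__name__}>'

class MISSING(metaclass=_BareReprMeta):
    """Sentinel standing in for 'no value given'."""

def lookup(mapping, key):
    # MISSING can never collide with a real stored value, unlike None.
    val = mapping.get(key, MISSING)
    return 'absent' if val is MISSING else val

assert repr(MISSING) == '<MISSING>'
assert lookup({'a': 1}, 'b') == 'absent'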
libs/bidict/_compat.py (deleted)
@@ -1,78 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
"""Compatibility helpers."""
from operator import methodcaller
from platform import python_implementation
from sys import version_info
from warnings import warn
# Use #: (before or) at the end of each line with a member we want to show up in the docs,
# otherwise Sphinx won't include (even though we configure automodule with undoc-members).
PYMAJOR, PYMINOR = version_info[:2] #:
PY2 = PYMAJOR == 2 #:
PYIMPL = python_implementation() #:
CPY = PYIMPL == 'CPython' #:
PYPY = PYIMPL == 'PyPy' #:
DICTS_ORDERED = PYPY or (CPY and (PYMAJOR, PYMINOR) >= (3, 6)) #:
# Without the following, pylint gives lots of false positives.
# pylint: disable=invalid-name,unused-import,ungrouped-imports,no-name-in-module
if PY2:
if PYMINOR < 7: # pragma: no cover
raise ImportError('Python 2.7 or 3.5+ is required.')
warn('Python 2 support will be dropped in a future release.')
# abstractproperty deprecated in Python 3.3 in favor of using @property with @abstractmethod.
# Before 3.3, this silently fails to detect when an abstract property has not been overridden.
from abc import abstractproperty #:
from itertools import izip #:
# In Python 3, the collections ABCs were moved into collections.abc, which does not exist in
# Python 2. Support for importing them directly from collections is dropped in Python 3.8.
import collections as collections_abc # noqa: F401 (imported but unused)
from collections import ( # noqa: F401 (imported but unused)
Mapping, MutableMapping, KeysView, ValuesView, ItemsView)
viewkeys = lambda m: m.viewkeys() if hasattr(m, 'viewkeys') else KeysView(m) #:
viewvalues = lambda m: m.viewvalues() if hasattr(m, 'viewvalues') else ValuesView(m) #:
viewitems = lambda m: m.viewitems() if hasattr(m, 'viewitems') else ItemsView(m) #:
iterkeys = lambda m: m.iterkeys() if hasattr(m, 'iterkeys') else iter(m.keys()) #:
itervalues = lambda m: m.itervalues() if hasattr(m, 'itervalues') else iter(m.values()) #:
iteritems = lambda m: m.iteritems() if hasattr(m, 'iteritems') else iter(m.items()) #:
else:
# Assume Python 3 when not PY2, but explicitly check before showing this warning.
if PYMAJOR == 3 and PYMINOR < 5: # pragma: no cover
warn('Python 3.4 and below are not supported.')
import collections.abc as collections_abc # noqa: F401 (imported but unused)
from collections.abc import ( # noqa: F401 (imported but unused)
Mapping, MutableMapping, KeysView, ValuesView, ItemsView)
viewkeys = methodcaller('keys') #:
viewvalues = methodcaller('values') #:
viewitems = methodcaller('items') #:
def _compose(f, g):
return lambda x: f(g(x))
iterkeys = _compose(iter, viewkeys) #:
itervalues = _compose(iter, viewvalues) #:
iteritems = _compose(iter, viewitems) #:
from abc import abstractmethod
abstractproperty = _compose(property, abstractmethod) #:
izip = zip #:
libs/bidict/metadata.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -8,42 +8,22 @@
"""Define bidict package metadata."""
__version__ = '0.0.0.VERSION_NOT_FOUND'
# _version.py is generated by setuptools_scm (via its `write_to` param, see setup.py)
try:
from ._version import version as __version__ # pylint: disable=unused-import
except (ImportError, ValueError, SystemError): # pragma: no cover
try:
import pkg_resources
except ImportError:
pass
else:
try:
__version__ = pkg_resources.get_distribution('bidict').version
except pkg_resources.DistributionNotFound:
pass
try:
__version_info__ = tuple(int(p) if i < 3 else p for (i, p) in enumerate(__version__.split('.')))
except Exception: # noqa: E722; pragma: no cover; pylint: disable=broad-except
__version_info__ = (0, 0, 0, 'PARSE FAILURE: __version__=%s' % __version__)
__author__ = u'Joshua Bronson'
__maintainer__ = u'Joshua Bronson'
__copyright__ = u'Copyright 2019 Joshua Bronson'
__email__ = u'jab@math.brown.edu'
__version__ = '0.21.4'
__author__ = 'Joshua Bronson'
__maintainer__ = 'Joshua Bronson'
__copyright__ = 'Copyright 2009-2021 Joshua Bronson'
__email__ = 'jabronson@gmail.com'
# See: ../docs/thanks.rst
__credits__ = [i.strip() for i in u"""
__credits__ = [i.strip() for i in """
Joshua Bronson, Michael Arntzenius, Francis Carr, Gregory Ewing, Raymond Hettinger, Jozef Knaperek,
Daniel Pope, Terry Reedy, David Turner, Tom Viner, Richard Sanger, Zeyi Wang
""".split(u',')]
""".split(',')]
__description__ = u'Efficient, Pythonic bidirectional map implementation and related functionality'
__description__ = 'The bidirectional mapping library for Python.'
__keywords__ = 'dict dictionary mapping datastructure bimap bijection bijective ' \
'injective inverse reverse bidirectional two-way 2-way'
__license__ = u'MPL 2.0'
__status__ = u'Beta'
__url__ = u'https://bidict.readthedocs.io'
__license__ = 'MPL 2.0'
__status__ = 'Beta'
__url__ = 'https://bidict.readthedocs.io'
libs/bs4/__init__.py
@@ -1,6 +1,5 @@
"""Beautiful Soup
Elixir and Tonic
"The Screen-Scraper's Friend"
"""Beautiful Soup Elixir and Tonic - "The Screen-Scraper's Friend".
http://www.crummy.com/software/BeautifulSoup/
Beautiful Soup uses a pluggable XML or HTML parser to parse a
@@ -8,29 +7,34 @@ Beautiful Soup uses a pluggable XML or HTML parser to parse a
provides methods and Pythonic idioms that make it easy to navigate,
search, and modify the parse tree.
Beautiful Soup works with Python 2.7 and up. It works better if lxml
Beautiful Soup works with Python 3.5 and up. It works better if lxml
and/or html5lib is installed.
For more than you ever wanted to know about Beautiful Soup, see the
documentation:
http://www.crummy.com/software/BeautifulSoup/bs4/doc/
documentation: http://www.crummy.com/software/BeautifulSoup/bs4/doc/
"""
__author__ = "Leonard Richardson (leonardr@segfault.org)"
__version__ = "4.8.0"
__copyright__ = "Copyright (c) 2004-2019 Leonard Richardson"
__version__ = "4.10.0"
__copyright__ = "Copyright (c) 2004-2021 Leonard Richardson"
# Use of this source code is governed by the MIT license.
__license__ = "MIT"
__all__ = ['BeautifulSoup']
from collections import Counter
import os
import re
import sys
import traceback
import warnings
# The very first thing we do is give a useful error if someone is
# running this code under Python 2.
if sys.version_info.major < 3:
raise ImportError('You are trying to use a Python 3-specific version of Beautiful Soup under Python 2. This will not work. The final version of Beautiful Soup to support Python 2 was 4.9.3.')
from .builder import builder_registry, ParserRejectedMarkup
from .dammit import UnicodeDammit
from .element import (
@@ -42,28 +46,49 @@ from .element import (
NavigableString,
PageElement,
ProcessingInstruction,
PYTHON_SPECIFIC_ENCODINGS,
ResultSet,
Script,
Stylesheet,
SoupStrainer,
Tag,
TemplateString,
)
# The very first thing we do is give a useful error if someone is
# running this code under Python 3 without converting it.
'You are trying to run the Python 2 version of Beautiful Soup under Python 3. This will not work.'!='You need to convert the code, either by installing it (`python setup.py install`) or by running 2to3 (`2to3 -w bs4`).'
# Define some custom warnings.
class GuessedAtParserWarning(UserWarning):
"""The warning issued when BeautifulSoup has to guess what parser to
use -- probably because no parser was specified in the constructor.
"""
class MarkupResemblesLocatorWarning(UserWarning):
"""The warning issued when BeautifulSoup is given 'markup' that
actually looks like a resource locator -- a URL or a path to a file
on disk.
"""
class BeautifulSoup(Tag):
"""
This class defines the basic interface called by the tree builders.
"""A data structure representing a parsed HTML or XML document.
These methods will be called by the parser:
reset()
feed(markup)
Most of the methods you'll call on a BeautifulSoup object are inherited from
PageElement or Tag.
Internally, this class defines the basic interface called by the
tree builders when converting an HTML/XML document into a data
structure. The interface abstracts away the differences between
parsers. To write a new tree builder, you'll need to understand
these methods as a whole.
These methods will be called by the BeautifulSoup constructor:
* reset()
* feed(markup)
The tree builder may call these methods from its feed() implementation:
handle_starttag(name, attrs) # See note about return value
handle_endtag(name)
handle_data(data) # Appends to the current data node
endData(containerClass=NavigableString) # Ends the current data node
* handle_starttag(name, attrs) # See note about return value
* handle_endtag(name)
* handle_data(data) # Appends to the current data node
* endData(containerClass) # Ends the current data node
No matter how complicated the underlying parser is, you should be
able to build a tree using 'start tag' events, 'end tag' events,
@@ -73,62 +98,75 @@ class BeautifulSoup(Tag):
like HTML's <br> tag), call handle_starttag and then
handle_endtag.
"""
# Since BeautifulSoup subclasses Tag, it's possible to treat it as
# a Tag with a .name. This name makes it clear the BeautifulSoup
# object isn't a real markup tag.
ROOT_TAG_NAME = '[document]'
# If the end-user gives no indication which tree builder they
# want, look for one with these features.
DEFAULT_BUILDER_FEATURES = ['html', 'fast']
# A string containing all ASCII whitespace characters, used in
# endData() to detect data chunks that seem 'empty'.
ASCII_SPACES = '\x20\x0a\x09\x0c\x0d'
NO_PARSER_SPECIFIED_WARNING = "No parser was explicitly specified, so I'm using the best available %(markup_type)s parser for this system (\"%(parser)s\"). This usually isn't a problem, but if you run this code on another system, or in a different virtual environment, it may use a different parser and behave differently.\n\nThe code that caused this warning is on line %(line_number)s of the file %(filename)s. To get rid of this warning, pass the additional argument 'features=\"%(parser)s\"' to the BeautifulSoup constructor.\n"
def __init__(self, markup="", features=None, builder=None,
parse_only=None, from_encoding=None, exclude_encodings=None,
**kwargs):
element_classes=None, **kwargs):
"""Constructor.
:param markup: A string or a file-like object representing
markup to be parsed.
markup to be parsed.
:param features: Desirable features of the parser to be used. This
may be the name of a specific parser ("lxml", "lxml-xml",
"html.parser", or "html5lib") or it may be the type of markup
to be used ("html", "html5", "xml"). It's recommended that you
name a specific parser, so that Beautiful Soup gives you the
same results across platforms and virtual environments.
:param features: Desirable features of the parser to be
used. This may be the name of a specific parser ("lxml",
"lxml-xml", "html.parser", or "html5lib") or it may be the
type of markup to be used ("html", "html5", "xml"). It's
recommended that you name a specific parser, so that
Beautiful Soup gives you the same results across platforms
and virtual environments.
:param builder: A TreeBuilder subclass to instantiate (or
instance to use) instead of looking one up based on
`features`. You only need to use this if you've implemented a
custom TreeBuilder.
instance to use) instead of looking one up based on
`features`. You only need to use this if you've implemented a
custom TreeBuilder.
:param parse_only: A SoupStrainer. Only parts of the document
matching the SoupStrainer will be considered. This is useful
when parsing part of a document that would otherwise be too
large to fit into memory.
matching the SoupStrainer will be considered. This is useful
when parsing part of a document that would otherwise be too
large to fit into memory.
:param from_encoding: A string indicating the encoding of the
document to be parsed. Pass this in if Beautiful Soup is
guessing wrongly about the document's encoding.
document to be parsed. Pass this in if Beautiful Soup is
guessing wrongly about the document's encoding.
:param exclude_encodings: A list of strings indicating
encodings known to be wrong. Pass this in if you don't know
the document's encoding but you know Beautiful Soup's guess is
wrong.
encodings known to be wrong. Pass this in if you don't know
the document's encoding but you know Beautiful Soup's guess is
wrong.
:param element_classes: A dictionary mapping BeautifulSoup
classes like Tag and NavigableString, to other classes you'd
like to be instantiated instead as the parse tree is
built. This is useful for subclassing Tag or NavigableString
to modify default behavior.
:param kwargs: For backwards compatibility purposes, the
constructor accepts certain keyword arguments used in
Beautiful Soup 3. None of these arguments do anything in
Beautiful Soup 4; they will result in a warning and then be ignored.
Apart from this, any keyword arguments passed into the BeautifulSoup
constructor are propagated to the TreeBuilder constructor. This
makes it possible to configure a TreeBuilder beyond saying
which one to use.
constructor accepts certain keyword arguments used in
Beautiful Soup 3. None of these arguments do anything in
Beautiful Soup 4; they will result in a warning and then be
ignored.
Apart from this, any keyword arguments passed into the
BeautifulSoup constructor are propagated to the TreeBuilder
constructor. This makes it possible to configure a
TreeBuilder by passing in arguments, not just by saying which
one to use.
"""
if 'convertEntities' in kwargs:
del kwargs['convertEntities']
warnings.warn(
@@ -185,6 +223,8 @@ class BeautifulSoup(Tag):
warnings.warn("You provided Unicode markup but also provided a value for from_encoding. Your from_encoding will be ignored.")
from_encoding = None
self.element_classes = element_classes or dict()
# We need this information to track whether or not the builder
# was specified well enough that we can omit the 'you need to
# specify a parser' warning.
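A hedged sketch of the new element_classes hook: a caller-supplied subclass is instantiated in place of Tag while the tree is built. LoudTag is an invented name for illustration:

from bs4 import BeautifulSoup
from bs4.element import Tag

class LoudTag(Tag):
    def shout(self):
        return self.name.upper()

soup = BeautifulSoup('<p>hi</p>', 'html.parser',
                     element_classes={Tag: LoudTag})
assert soup.p.shout() == 'P'  # parsed tags come back as LoudTag instances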
@@ -215,7 +255,9 @@ class BeautifulSoup(Tag):
if not original_builder and not (
original_features == builder.NAME or
original_features in builder.ALTERNATE_NAMES
):
) and markup:
# The user did not tell us which TreeBuilder to use,
# and we had to guess. Issue a warning.
if builder.is_xml:
markup_type = "XML"
else:
@@ -249,7 +291,10 @@ class BeautifulSoup(Tag):
parser=builder.NAME,
markup_type=markup_type
)
warnings.warn(self.NO_PARSER_SPECIFIED_WARNING % values, stacklevel=2)
warnings.warn(
self.NO_PARSER_SPECIFIED_WARNING % values,
GuessedAtParserWarning, stacklevel=2
)
else:
if kwargs:
warnings.warn("Keyword arguments to the BeautifulSoup constructor will be ignored. These would normally be passed into the TreeBuilder constructor, but a TreeBuilder instance was passed in as `builder`.")
@@ -278,22 +323,36 @@ class BeautifulSoup(Tag):
else:
possible_filename = markup
is_file = False
is_directory = False
try:
is_file = os.path.exists(possible_filename)
if is_file:
is_directory = os.path.isdir(possible_filename)
except Exception as e:
# This is almost certainly a problem involving
# characters not valid in filenames on this
# system. Just let it go.
pass
if is_file:
if isinstance(markup, str):
markup = markup.encode("utf8")
if is_directory:
warnings.warn(
'"%s" looks like a directory name, not markup. You may'
' want to open a file found in this directory and pass'
' the filehandle into Beautiful Soup.' % (
self._decode_markup(markup)
),
MarkupResemblesLocatorWarning
)
elif is_file:
warnings.warn(
'"%s" looks like a filename, not markup. You should'
' probably open this file and pass the filehandle into'
' Beautiful Soup.' % markup)
' Beautiful Soup.' % self._decode_markup(markup),
MarkupResemblesLocatorWarning
)
self._check_markup_is_url(markup)
rejections = []
success = False
for (self.markup, self.original_encoding, self.declared_html_encoding,
self.contains_replacement_characters) in (
self.builder.prepare_markup(
@@ -301,16 +360,25 @@ class BeautifulSoup(Tag):
self.reset()
try:
self._feed()
success = True
break
except ParserRejectedMarkup:
except ParserRejectedMarkup as e:
rejections.append(e)
pass
if not success:
other_exceptions = [str(e) for e in rejections]
raise ParserRejectedMarkup(
"The markup you provided was rejected by the parser. Trying a different parser or a different encoding may help.\n\nOriginal exception(s) from parser:\n " + "\n ".join(other_exceptions)
)
# Clear out the markup and remove the builder's circular
# reference to this object.
self.markup = None
self.builder.soup = None
def __copy__(self):
"""Copy a BeautifulSoup object by converting the document to a string and parsing it again."""
copy = type(self)(
self.encode('utf-8'), builder=self.builder, from_encoding='utf-8'
)
@@ -329,11 +397,25 @@ class BeautifulSoup(Tag):
d['builder'] = None
return d
@staticmethod
def _check_markup_is_url(markup):
"""
Check if markup looks like it's actually a url and raise a warning
if so. Markup can be unicode or str (py2) / bytes (py3).
@classmethod
def _decode_markup(cls, markup):
"""Ensure `markup` is bytes so it's safe to send into warnings.warn.
TODO: warnings.warn had this problem back in 2010 but it might not
anymore.
"""
if isinstance(markup, bytes):
decoded = markup.decode('utf-8', 'replace')
else:
decoded = markup
return decoded
@classmethod
def _check_markup_is_url(cls, markup):
"""Error-handling method to raise a warning if incoming markup looks
like a URL.
:param markup: A string.
"""
if isinstance(markup, bytes):
space = b' '
@@ -346,18 +428,20 @@ class BeautifulSoup(Tag):
if any(markup.startswith(prefix) for prefix in cant_start_with):
if not space in markup:
if isinstance(markup, bytes):
decoded_markup = markup.decode('utf-8', 'replace')
else:
decoded_markup = markup
warnings.warn(
'"%s" looks like a URL. Beautiful Soup is not an'
' HTTP client. You should probably use an HTTP client like'
' requests to get the document behind the URL, and feed'
' that document to Beautiful Soup.' % decoded_markup
' that document to Beautiful Soup.' % cls._decode_markup(
markup
),
MarkupResemblesLocatorWarning
)
def _feed(self):
"""Internal method that parses previously set markup, creating a large
number of Tag and NavigableString objects.
"""
# Convert the document to Unicode.
self.builder.reset()
@@ -368,49 +452,110 @@ class BeautifulSoup(Tag):
self.popTag()
def reset(self):
"""Reset this object to a state as though it had never parsed any
markup.
"""
Tag.__init__(self, self, self.builder, self.ROOT_TAG_NAME)
self.hidden = 1
self.builder.reset()
self.current_data = []
self.currentTag = None
self.tagStack = []
self.open_tag_counter = Counter()
self.preserve_whitespace_tag_stack = []
self.string_container_stack = []
self.pushTag(self)
def new_tag(self, name, namespace=None, nsprefix=None, attrs={}, **kwattrs):
"""Create a new tag associated with this soup."""
def new_tag(self, name, namespace=None, nsprefix=None, attrs={},
sourceline=None, sourcepos=None, **kwattrs):
"""Create a new Tag associated with this BeautifulSoup object.
:param name: The name of the new Tag.
:param namespace: The URI of the new Tag's XML namespace, if any.
:param prefix: The prefix for the new Tag's XML namespace, if any.
:param attrs: A dictionary of this Tag's attribute values; can
be used instead of `kwattrs` for attributes like 'class'
that are reserved words in Python.
:param sourceline: The line number where this tag was
(purportedly) found in its source document.
:param sourcepos: The character position within `sourceline` where this
tag was (purportedly) found.
:param kwattrs: Keyword arguments for the new Tag's attribute values.
"""
kwattrs.update(attrs)
return Tag(None, self.builder, name, namespace, nsprefix, kwattrs)
return self.element_classes.get(Tag, Tag)(
None, self.builder, name, namespace, nsprefix, kwattrs,
sourceline=sourceline, sourcepos=sourcepos
)
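A short sketch of the extended new_tag signature; the positions are stored as given, hence "purportedly":

from bs4 import BeautifulSoup

soup = BeautifulSoup('', 'html.parser')
tag = soup.new_tag('a', href='https://example.com',
                   sourceline=10, sourcepos=4)
assert tag['href'] == 'https://example.com'
assert (tag.sourceline, tag.sourcepos) == (10, 4)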
def new_string(self, s, subclass=NavigableString):
"""Create a new NavigableString associated with this soup."""
return subclass(s)
def string_container(self, base_class=None):
container = base_class or NavigableString
# There may be a general override of NavigableString.
container = self.element_classes.get(
container, container
)
def insert_before(self, successor):
# On top of that, we may be inside a tag that needs a special
# container class.
if self.string_container_stack and container is NavigableString:
container = self.builder.string_containers.get(
self.string_container_stack[-1].name, container
)
return container
def new_string(self, s, subclass=None):
"""Create a new NavigableString associated with this BeautifulSoup
object.
"""
container = self.string_container(subclass)
return container(s)
def insert_before(self, *args):
"""This method is part of the PageElement API, but `BeautifulSoup` doesn't implement
it because there is nothing before or after it in the parse tree.
"""
raise NotImplementedError("BeautifulSoup objects don't support insert_before().")
def insert_after(self, successor):
def insert_after(self, *args):
"""This method is part of the PageElement API, but `BeautifulSoup` doesn't implement
it because there is nothing before or after it in the parse tree.
"""
raise NotImplementedError("BeautifulSoup objects don't support insert_after().")
def popTag(self):
"""Internal method called by _popToTag when a tag is closed."""
tag = self.tagStack.pop()
if tag.name in self.open_tag_counter:
self.open_tag_counter[tag.name] -= 1
if self.preserve_whitespace_tag_stack and tag == self.preserve_whitespace_tag_stack[-1]:
self.preserve_whitespace_tag_stack.pop()
#print "Pop", tag.name
if self.string_container_stack and tag == self.string_container_stack[-1]:
self.string_container_stack.pop()
#print("Pop", tag.name)
if self.tagStack:
self.currentTag = self.tagStack[-1]
return self.currentTag
def pushTag(self, tag):
#print "Push", tag.name
"""Internal method called by handle_starttag when a tag is opened."""
#print("Push", tag.name)
if self.currentTag is not None:
self.currentTag.contents.append(tag)
self.tagStack.append(tag)
self.currentTag = self.tagStack[-1]
if tag.name != self.ROOT_TAG_NAME:
self.open_tag_counter[tag.name] += 1
if tag.name in self.builder.preserve_whitespace_tags:
self.preserve_whitespace_tag_stack.append(tag)
if tag.name in self.builder.string_containers:
self.string_container_stack.append(tag)
def endData(self, containerClass=NavigableString):
def endData(self, containerClass=None):
"""Method called by the TreeBuilder when the end of a data segment
occurs.
"""
if self.current_data:
current_data = ''.join(self.current_data)
# If whitespace is not preserved, and this string contains
@@ -437,11 +582,12 @@ class BeautifulSoup(Tag):
not self.parse_only.search(current_data)):
return
containerClass = self.string_container(containerClass)
o = containerClass(current_data)
self.object_was_parsed(o)
def object_was_parsed(self, o, parent=None, most_recent_element=None):
"""Add an object to the parse tree."""
"""Method called by the TreeBuilder to integrate an object into the parse tree."""
if parent is None:
parent = self.currentTag
if most_recent_element is not None:
@@ -510,10 +656,19 @@ class BeautifulSoup(Tag):
def _popToTag(self, name, nsprefix=None, inclusivePop=True):
"""Pops the tag stack up to and including the most recent
instance of the given tag. If inclusivePop is false, pops the tag
stack up to but *not* including the most recent instance of
the given tag."""
#print "Popping to %s" % name
instance of the given tag.
If there are no open tags with the given name, nothing will be
popped.
:param name: Pop up to the most recent tag with this name.
:param nsprefix: The namespace prefix that goes with `name`.
:param inclusivePop: If this is false, pops the tag stack up
to but *not* including the most recent instance of the
given tag.
"""
#print("Popping to %s" % name)
if name == self.ROOT_TAG_NAME:
# The BeautifulSoup object itself can never be popped.
return
@@ -522,6 +677,8 @@ class BeautifulSoup(Tag):
stack_size = len(self.tagStack)
for i in range(stack_size - 1, 0, -1):
if not self.open_tag_counter.get(name):
break
t = self.tagStack[i]
if (name == t.name and nsprefix == t.prefix):
if inclusivePop:
@@ -531,16 +688,24 @@ class BeautifulSoup(Tag):
return most_recently_popped
def handle_starttag(self, name, namespace, nsprefix, attrs):
"""Push a start tag on to the stack.
def handle_starttag(self, name, namespace, nsprefix, attrs, sourceline=None,
sourcepos=None):
"""Called by the tree builder when a new tag is encountered.
If this method returns None, the tag was rejected by the
:param name: Name of the tag.
:param nsprefix: Namespace prefix for the tag.
:param attrs: A dictionary of attribute values.
:param sourceline: The line number where this tag was found in its
source document.
:param sourcepos: The character position within `sourceline` where this
tag was found.
If this method returns None, the tag was rejected by an active
SoupStrainer. You should proceed as if the tag had not occurred
in the document. For instance, if this was a self-closing tag,
don't call handle_endtag.
"""
# print "Start tag %s: %s" % (name, attrs)
# print("Start tag %s: %s" % (name, attrs))
self.endData()
if (self.parse_only and len(self.tagStack) <= 1
@@ -548,8 +713,11 @@ class BeautifulSoup(Tag):
or not self.parse_only.search_tag(name, attrs))):
return None
tag = Tag(self, self.builder, name, namespace, nsprefix, attrs,
self.currentTag, self._most_recent_element)
tag = self.element_classes.get(Tag, Tag)(
self, self.builder, name, namespace, nsprefix, attrs,
self.currentTag, self._most_recent_element,
sourceline=sourceline, sourcepos=sourcepos
)
if tag is None:
return tag
if self._most_recent_element is not None:
@@ -559,22 +727,38 @@ class BeautifulSoup(Tag):
return tag
def handle_endtag(self, name, nsprefix=None):
#print "End tag: " + name
"""Called by the tree builder when an ending tag is encountered.
:param name: Name of the tag.
:param nsprefix: Namespace prefix for the tag.
"""
#print("End tag: " + name)
self.endData()
self._popToTag(name, nsprefix)
def handle_data(self, data):
"""Called by the tree builder when a chunk of textual data is encountered."""
self.current_data.append(data)
def decode(self, pretty_print=False,
eventual_encoding=DEFAULT_OUTPUT_ENCODING,
formatter="minimal"):
"""Returns a string or Unicode representation of this document.
To get Unicode, pass None for encoding."""
"""Returns a string or Unicode representation of the parse tree
as an HTML or XML document.
:param pretty_print: If this is True, indentation will be used to
make the document more readable.
:param eventual_encoding: The encoding of the final document.
If this is None, the document will be a Unicode string.
"""
if self.is_xml:
# Print the XML declaration
encoding_part = ''
if eventual_encoding in PYTHON_SPECIFIC_ENCODINGS:
# This is a special Python encoding; it can't actually
# go into an XML document because it means nothing
# outside of Python.
eventual_encoding = None
if eventual_encoding != None:
encoding_part = ' encoding="%s"' % eventual_encoding
prefix = '<?xml version="1.0"%s?>\n' % encoding_part
@@ -587,7 +771,7 @@ class BeautifulSoup(Tag):
return prefix + super(BeautifulSoup, self).decode(
indent_level, eventual_encoding, formatter)
# Alias to make it easier to type import: 'from bs4 import _soup'
# Aliases to make it easier to get started quickly, e.g. 'from bs4 import _soup'
_s = BeautifulSoup
_soup = BeautifulSoup
@@ -603,14 +787,18 @@ class BeautifulStoneSoup(BeautifulSoup):
class StopParsing(Exception):
"""Exception raised by a TreeBuilder if it's unable to continue parsing."""
pass
class FeatureNotFound(ValueError):
"""Exception raised by the BeautifulSoup constructor if no parser with the
requested features is found.
"""
pass
#By default, act as an HTML pretty-printer.
#If this file is run as a script, act as an HTML pretty-printer.
if __name__ == '__main__':
import sys
soup = BeautifulSoup(sys.stdin)
print(soup.prettify())
print(soup.prettify())
libs/bs4/builder/__init__.py
@@ -7,8 +7,11 @@ import sys
from bs4.element import (
CharsetMetaAttributeValue,
ContentMetaAttributeValue,
Stylesheet,
Script,
TemplateString,
nonwhitespace_re
)
)
__all__ = [
'HTMLTreeBuilder',
@@ -27,18 +30,33 @@ HTML_5 = 'html5'
class TreeBuilderRegistry(object):
"""A way of looking up TreeBuilder subclasses by their name or by desired
features.
"""
def __init__(self):
self.builders_for_feature = defaultdict(list)
self.builders = []
def register(self, treebuilder_class):
"""Register a treebuilder based on its advertised features."""
"""Register a treebuilder based on its advertised features.
:param treebuilder_class: A subclass of TreeBuilder. Its .features
attribute should list its features.
"""
for feature in treebuilder_class.features:
self.builders_for_feature[feature].insert(0, treebuilder_class)
self.builders.insert(0, treebuilder_class)
def lookup(self, *features):
"""Look up a TreeBuilder subclass with the desired features.
:param features: A list of features to look for. If none are
provided, the most recently registered TreeBuilder subclass
will be used.
:return: A TreeBuilder subclass, or None if there's no
registered subclass with all the requested features.
"""
if len(self.builders) == 0:
# There are no builders at all.
return None
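Illustrative lookups against the module-level registry, using feature names as registered by the built-in builders:

from bs4.builder import builder_registry

cls = builder_registry.lookup('html.parser')
print(cls.NAME if cls else 'no matching builder')

# All requested features must be satisfied by one builder class;
# this may return None if, e.g., no fast HTML parser is installed.
cls = builder_registry.lookup('html', 'fast')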
@@ -81,7 +99,7 @@ class TreeBuilderRegistry(object):
builder_registry = TreeBuilderRegistry()
class TreeBuilder(object):
"""Turn a document into a Beautiful Soup object tree."""
"""Turn a textual document into a Beautiful Soup object tree."""
NAME = "[Unknown tree builder]"
ALTERNATE_NAMES = []
@@ -96,24 +114,53 @@ class TreeBuilder(object):
# comma-separated list of CDATA, rather than a single CDATA.
DEFAULT_CDATA_LIST_ATTRIBUTES = {}
# Whitespace should be preserved inside these tags.
DEFAULT_PRESERVE_WHITESPACE_TAGS = set()
# The textual contents of tags with these names should be
# instantiated with some class other than NavigableString.
DEFAULT_STRING_CONTAINERS = {}
USE_DEFAULT = object()
# Most parsers don't keep track of line numbers.
TRACKS_LINE_NUMBERS = False
def __init__(self, multi_valued_attributes=USE_DEFAULT, preserve_whitespace_tags=USE_DEFAULT):
def __init__(self, multi_valued_attributes=USE_DEFAULT,
preserve_whitespace_tags=USE_DEFAULT,
store_line_numbers=USE_DEFAULT,
string_containers=USE_DEFAULT,
):
"""Constructor.
:param multi_valued_attributes: If this is set to None, the
TreeBuilder will not turn any values for attributes like
'class' into lists. Setting this do a dictionary will
customize this behavior; look at DEFAULT_CDATA_LIST_ATTRIBUTES
for an example.
TreeBuilder will not turn any values for attributes like
'class' into lists. Setting this to a dictionary will
customize this behavior; look at DEFAULT_CDATA_LIST_ATTRIBUTES
for an example.
Internally, these are called "CDATA list attributes", but that
probably doesn't make sense to an end-user, so the argument name
is `multi_valued_attributes`.
Internally, these are called "CDATA list attributes", but that
probably doesn't make sense to an end-user, so the argument name
is `multi_valued_attributes`.
:param preserve_whitespace_tags:
:param preserve_whitespace_tags: A list of tags to treat
the way <pre> tags are treated in HTML. Tags in this list
are immune from pretty-printing; their contents will always be
output as-is.
:param string_containers: A dictionary mapping tag names to
the classes that should be instantiated to contain the textual
contents of those tags. The default is to use NavigableString
for every tag, no matter what the name. You can override the
default by changing DEFAULT_STRING_CONTAINERS.
:param store_line_numbers: If the parser keeps track of the
line numbers and positions of the original markup, that
information will, by default, be stored in each corresponding
`Tag` object. You can turn this off by passing
store_line_numbers=False. If the parser you're using doesn't
keep track of this information, then setting store_line_numbers=True
will do nothing.
"""
self.soup = None
if multi_valued_attributes is self.USE_DEFAULT:
@@ -122,14 +169,27 @@ class TreeBuilder(object):
if preserve_whitespace_tags is self.USE_DEFAULT:
preserve_whitespace_tags = self.DEFAULT_PRESERVE_WHITESPACE_TAGS
self.preserve_whitespace_tags = preserve_whitespace_tags
if store_line_numbers == self.USE_DEFAULT:
store_line_numbers = self.TRACKS_LINE_NUMBERS
self.store_line_numbers = store_line_numbers
if string_containers == self.USE_DEFAULT:
string_containers = self.DEFAULT_STRING_CONTAINERS
self.string_containers = string_containers
def initialize_soup(self, soup):
"""The BeautifulSoup object has been initialized and is now
being associated with the TreeBuilder.
:param soup: A BeautifulSoup object.
"""
self.soup = soup
def reset(self):
"""Do any work necessary to reset the underlying parser
for a new document.
By default, this does nothing.
"""
pass
def can_be_empty_element(self, tag_name):
@@ -141,24 +201,58 @@ class TreeBuilder(object):
For instance: an HTMLBuilder does not consider a <p> tag to be
an empty-element tag (it's not in
HTMLBuilder.empty_element_tags). This means an empty <p> tag
will be presented as "<p></p>", not "<p />".
will be presented as "<p></p>", not "<p/>" or "<p>".
The default implementation has no opinion about which tags are
empty-element tags, so a tag will be presented as an
empty-element tag if and only if it has no contents.
"<foo></foo>" will become "<foo />", and "<foo>bar</foo>" will
empty-element tag if and only if it has no children.
"<foo></foo>" will become "<foo/>", and "<foo>bar</foo>" will
be left alone.
:param tag_name: The name of a markup tag.
"""
if self.empty_element_tags is None:
return True
return tag_name in self.empty_element_tags
def feed(self, markup):
"""Run some incoming markup through some parsing process,
populating the `BeautifulSoup` object in self.soup.
This method is not implemented in TreeBuilder; it must be
implemented in subclasses.
:return: None.
"""
raise NotImplementedError()
def prepare_markup(self, markup, user_specified_encoding=None,
document_declared_encoding=None):
return markup, None, None, False
document_declared_encoding=None, exclude_encodings=None):
"""Run any preliminary steps necessary to make incoming markup
acceptable to the parser.
:param markup: Some markup -- probably a bytestring.
:param user_specified_encoding: The user asked to try this encoding.
:param document_declared_encoding: The markup itself claims to be
in this encoding. NOTE: This argument is not used by the
calling code and can probably be removed.
:param exclude_encodings: The user asked _not_ to try any of
these encodings.
:yield: A series of 4-tuples:
(markup, encoding, declared encoding,
has undergone character replacement)
Each 4-tuple represents a strategy for converting the
document to Unicode and parsing it. Each strategy will be tried
in turn.
By default, the only strategy is to parse the markup
as-is. See `LXMLTreeBuilderForXML` and
`HTMLParserTreeBuilder` for implementations that take into
account the quirks of particular parsers.
"""
yield markup, None, None, False
def test_fragment_to_document(self, fragment):
"""Wrap an HTML fragment to make it look like a document.
@@ -170,16 +264,36 @@ class TreeBuilder(object):
results against other HTML fragments.
This method should not be used outside of tests.
:param fragment: A string -- fragment of HTML.
:return: A string -- a full HTML document.
"""
return fragment
def set_up_substitutions(self, tag):
"""Set up any substitutions that will need to be performed on
a `Tag` when it's output as a string.
By default, this does nothing. See `HTMLTreeBuilder` for a
case where this is used.
:param tag: A `Tag`
:return: Whether or not a substitution was performed.
"""
return False
def _replace_cdata_list_attribute_values(self, tag_name, attrs):
"""Replaces class="foo bar" with class=["foo", "bar"]
"""When an attribute value is associated with a tag that can
have multiple values for that attribute, convert the string
value to a list of strings.
Modifies its input in place.
Basically, replaces class="foo bar" with class=["foo", "bar"]
NOTE: This method modifies its input in place.
:param tag_name: The name of a tag.
:param attrs: A dictionary containing the tag's attributes.
Any appropriate attribute values will be modified in place.
"""
if not attrs:
return attrs
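The behavior in practice, sketched against the documented API (the multi_valued_attributes setting is forwarded from the BeautifulSoup constructor to the builder):

from bs4 import BeautifulSoup

soup = BeautifulSoup('<p class="foo bar" id="baz">x</p>', 'html.parser')
assert soup.p['class'] == ['foo', 'bar']  # space-separated value -> list
assert soup.p['id'] == 'baz'              # 'id' is single-valued

# Passing None disables the splitting entirely:
soup = BeautifulSoup('<p class="foo bar">x</p>', 'html.parser',
                     multi_valued_attributes=None)
assert soup.p['class'] == 'foo bar'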
@@ -207,7 +321,11 @@ class TreeBuilder(object):
return attrs
class SAXTreeBuilder(TreeBuilder):
"""A Beautiful Soup treebuilder that listens for SAX events."""
"""A Beautiful Soup treebuilder that listens for SAX events.
This is not currently used for anything, but it demonstrates
how a simple TreeBuilder would work.
"""
def feed(self, markup):
raise NotImplementedError()
@@ -217,11 +335,11 @@ class SAXTreeBuilder(TreeBuilder):
def startElement(self, name, attrs):
attrs = dict((key[1], value) for key, value in list(attrs.items()))
#print "Start %s, %r" % (name, attrs)
#print("Start %s, %r" % (name, attrs))
self.soup.handle_starttag(name, attrs)
def endElement(self, name):
#print "End %s" % name
#print("End %s" % name)
self.soup.handle_endtag(name)
def startElementNS(self, nsTuple, nodeName, attrs):
@@ -271,6 +389,22 @@ class HTMLTreeBuilder(TreeBuilder):
# but it may do so eventually, and this information is available if
# you need to use it.
block_elements = set(["address", "article", "aside", "blockquote", "canvas", "dd", "div", "dl", "dt", "fieldset", "figcaption", "figure", "footer", "form", "h1", "h2", "h3", "h4", "h5", "h6", "header", "hr", "li", "main", "nav", "noscript", "ol", "output", "p", "pre", "section", "table", "tfoot", "ul", "video"])
# The HTML standard defines an unusual content model for these tags.
# We represent this by using a string class other than NavigableString
# inside these tags.
#
# I made this list by going through the HTML spec
# (https://html.spec.whatwg.org/#metadata-content) and looking for
# "metadata content" elements that can contain strings.
#
# TODO: Arguably <noscript> could go here but it seems
# qualitatively different from the other tags.
DEFAULT_STRING_CONTAINERS = {
'style': Stylesheet,
'script': Script,
'template': TemplateString,
}
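The observable effect, sketched under the assumption that the vendored bs4 matches upstream 4.10: strings inside these tags come back as the dedicated NavigableString subclasses rather than plain NavigableString.

from bs4 import BeautifulSoup
from bs4.element import Script, Stylesheet

soup = BeautifulSoup('<script>var x = 1;</script><style>p {}</style>',
                     'html.parser')
assert isinstance(soup.script.string, Script)
assert isinstance(soup.style.string, Stylesheet)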
# The HTML standard defines these attributes as containing a
# space-separated list of values, not a single value. That is,
@@ -299,6 +433,16 @@ class HTMLTreeBuilder(TreeBuilder):
DEFAULT_PRESERVE_WHITESPACE_TAGS = set(['pre', 'textarea'])
def set_up_substitutions(self, tag):
"""Replace the declared encoding in a <meta> tag with a placeholder,
to be substituted when the tag is output to a string.
An HTML document may come in to Beautiful Soup as one
encoding, but exit in a different encoding, and the <meta> tag
needs to be changed to reflect this.
:param tag: A `Tag`
:return: Whether or not a substitution was performed.
"""
# We are only interested in <meta> tags
if tag.name != 'meta':
return False
@@ -333,8 +477,7 @@ class HTMLTreeBuilder(TreeBuilder):
def register_treebuilders_from(module):
"""Copy TreeBuilders from the given module into this module."""
# I'm fairly sure this is not the best way to do this.
this_module = sys.modules['bs4.builder']
this_module = sys.modules[__name__]
for name in module.__all__:
obj = getattr(module, name)
@@ -345,12 +488,22 @@ def register_treebuilders_from(module):
this_module.builder_registry.register(obj)
class ParserRejectedMarkup(Exception):
pass
"""An Exception to be raised when the underlying parser simply
refuses to parse the given markup.
"""
def __init__(self, message_or_exception):
"""Explain why the parser rejected the given markup, either
with a textual explanation or another exception.
"""
if isinstance(message_or_exception, Exception):
e = message_or_exception
message_or_exception = "%s: %s" % (e.__class__.__name__, str(e))
super(ParserRejectedMarkup, self).__init__(message_or_exception)
# Builders are registered in reverse order of priority, so that custom
# builder registrations will take precedence. In general, we want lxml
# to take precedence over html5lib, because it's faster. And we only
# want to use HTMLParser as a last result.
# want to use HTMLParser as a last resort.
from . import _htmlparser
register_treebuilders_from(_htmlparser)
try:
libs/bs4/builder/_html5lib.py
@@ -39,12 +39,27 @@ except ImportError as e:
new_html5lib = True
class HTML5TreeBuilder(HTMLTreeBuilder):
"""Use html5lib to build a tree."""
"""Use html5lib to build a tree.
Note that this TreeBuilder does not support some features common
to HTML TreeBuilders. Some of these features could theoretically
be implemented, but at the very least it's quite difficult,
because html5lib moves the parse tree around as it's being built.
* This TreeBuilder doesn't use different subclasses of NavigableString
based on the name of the tag in which the string was found.
* You can't use a SoupStrainer to parse only part of a document.
"""
NAME = "html5lib"
features = [NAME, PERMISSIVE, HTML_5, HTML]
# html5lib can tell us which line number and position in the
# original file is the source of an element.
TRACKS_LINE_NUMBERS = True
def prepare_markup(self, markup, user_specified_encoding,
document_declared_encoding=None, exclude_encodings=None):
# Store the user-specified encoding for use later on.
@@ -62,7 +77,7 @@ class HTML5TreeBuilder(HTMLTreeBuilder):
if self.soup.parse_only is not None:
warnings.warn("You provided a value for parse_only, but the html5lib tree builder doesn't support parse_only. The entire document will be parsed.")
parser = html5lib.HTMLParser(tree=self.create_treebuilder)
self.underlying_builder.parser = parser
extra_kwargs = dict()
if not isinstance(markup, str):
if new_html5lib:
@@ -70,7 +85,7 @@ class HTML5TreeBuilder(HTMLTreeBuilder):
else:
extra_kwargs['encoding'] = self.user_specified_encoding
doc = parser.parse(markup, **extra_kwargs)
# Set the character encoding detected by the tokenizer.
if isinstance(markup, str):
# We need to special-case this because html5lib sets
@@ -84,10 +99,13 @@ class HTML5TreeBuilder(HTMLTreeBuilder):
# with other tree builders.
original_encoding = original_encoding.name
doc.original_encoding = original_encoding
self.underlying_builder.parser = None
def create_treebuilder(self, namespaceHTMLElements):
self.underlying_builder = TreeBuilderForHtml5lib(
namespaceHTMLElements, self.soup)
namespaceHTMLElements, self.soup,
store_line_numbers=self.store_line_numbers
)
return self.underlying_builder
def test_fragment_to_document(self, fragment):
@@ -96,15 +114,29 @@ class HTML5TreeBuilder(HTMLTreeBuilder):
class TreeBuilderForHtml5lib(treebuilder_base.TreeBuilder):
def __init__(self, namespaceHTMLElements, soup=None):
def __init__(self, namespaceHTMLElements, soup=None,
store_line_numbers=True, **kwargs):
if soup:
self.soup = soup
else:
from bs4 import BeautifulSoup
self.soup = BeautifulSoup("", "html.parser")
# TODO: Why is the parser 'html.parser' here? To avoid an
# infinite loop?
self.soup = BeautifulSoup(
"", "html.parser", store_line_numbers=store_line_numbers,
**kwargs
)
# TODO: What are **kwargs exactly? Should they be passed in
# here in addition to/instead of being passed to the BeautifulSoup
# constructor?
super(TreeBuilderForHtml5lib, self).__init__(namespaceHTMLElements)
# This will be set later to an html5lib.html5parser.HTMLParser
# object, which we can use to track the current line number.
self.parser = None
self.store_line_numbers = store_line_numbers
def documentClass(self):
self.soup.reset()
return Element(self.soup, self.soup, None)
@@ -118,7 +150,16 @@ class TreeBuilderForHtml5lib(treebuilder_base.TreeBuilder):
self.soup.object_was_parsed(doctype)
def elementClass(self, name, namespace):
tag = self.soup.new_tag(name, namespace)
kwargs = {}
if self.parser and self.store_line_numbers:
# This represents the point immediately after the end of the
# tag. We don't know when the tag started, but we do know
# where it ended -- the character just before this one.
sourceline, sourcepos = self.parser.tokenizer.stream.position()
kwargs['sourceline'] = sourceline
kwargs['sourcepos'] = sourcepos-1
tag = self.soup.new_tag(name, namespace, **kwargs)
return Element(tag, self.soup, namespace)
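With html5lib installed, parsed tags therefore carry approximate source positions; note the position html5lib reports is where the tag ended, not where it started. Illustrative only:

from bs4 import BeautifulSoup  # assumes html5lib is importable

soup = BeautifulSoup('<p>one</p>\n<p>two</p>', 'html5lib')
for p in soup.find_all('p'):
    print(p.name, p.sourceline, p.sourcepos)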
def commentClass(self, data):
@@ -126,6 +167,8 @@ class TreeBuilderForHtml5lib(treebuilder_base.TreeBuilder):
def fragmentClass(self):
from bs4 import BeautifulSoup
# TODO: Why is the parser 'html.parser' here? To avoid an
# infinite loop?
self.soup = BeautifulSoup("", "html.parser")
self.soup.name = "[document_fragment]"
return Element(self.soup, self.soup, None)
@@ -287,9 +330,7 @@ class Element(treebuilder_base.Node):
return AttrList(self.element)
def setAttributes(self, attributes):
if attributes is not None and len(attributes) > 0:
converted_attributes = []
for name, value in list(attributes.items()):
if isinstance(name, tuple):
@@ -334,9 +375,9 @@ class Element(treebuilder_base.Node):
def reparentChildren(self, new_parent):
"""Move all of this tag's children into another tag."""
# print "MOVE", self.element.contents
# print "FROM", self.element
# print "TO", new_parent.element
# print("MOVE", self.element.contents)
# print("FROM", self.element)
# print("TO", new_parent.element)
element = self.element
new_parent_element = new_parent.element
@@ -394,9 +435,9 @@ class Element(treebuilder_base.Node):
element.contents = []
element.next_element = final_next_element
# print "DONE WITH MOVE"
# print "FROM", self.element
# print "TO", new_parent_element
# print("DONE WITH MOVE")
# print("FROM", self.element)
# print("TO", new_parent_element)
def cloneNode(self):
tag = self.soup.new_tag(self.element.name, self.namespace)
libs/bs4/builder/_htmlparser.py
@@ -53,8 +53,30 @@ from bs4.builder import (
HTMLPARSER = 'html.parser'
class BeautifulSoupHTMLParser(HTMLParser):
"""A subclass of the Python standard library's HTMLParser class, which
listens for HTMLParser events and translates them into calls
to Beautiful Soup's tree construction API.
"""
# Strategies for handling duplicate attributes
IGNORE = 'ignore'
REPLACE = 'replace'
def __init__(self, *args, **kwargs):
"""Constructor.
:param on_duplicate_attribute: A strategy for what to do if a
tag includes the same attribute more than once. Accepted
values are: REPLACE (replace earlier values with later
ones, the default), IGNORE (keep the earliest value
encountered), or a callable. A callable must take three
arguments: the dictionary of attributes already processed,
the name of the duplicate attribute, and the most recent value
encountered.
"""
self.on_duplicate_attribute = kwargs.pop(
'on_duplicate_attribute', self.REPLACE
)
HTMLParser.__init__(self, *args, **kwargs)
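Keyword arguments to BeautifulSoup are forwarded to the tree builder, so the new strategy can be selected from the constructor (html.parser only). A hedged example:

from bs4 import BeautifulSoup

markup = '<a href="first" href="second">x</a>'

soup = BeautifulSoup(markup, 'html.parser', on_duplicate_attribute='ignore')
assert soup.a['href'] == 'first'   # earliest value kept

soup = BeautifulSoup(markup, 'html.parser')  # default strategy: 'replace'
assert soup.a['href'] == 'second'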
# Keep a list of empty-element tags that were encountered
@@ -67,20 +89,26 @@ class BeautifulSoupHTMLParser(HTMLParser):
self.already_closed_empty_element = []
def error(self, msg):
"""In Python 3, HTMLParser subclasses must implement error(), although this
requirement doesn't appear to be documented.
"""In Python 3, HTMLParser subclasses must implement error(), although
this requirement doesn't appear to be documented.
In Python 2, HTMLParser implements error() as raising an exception.
In Python 2, HTMLParser implements error() by raising an exception,
which we don't want to do.
In any event, this method is called only on very strange markup and our best strategy
is to pretend it didn't happen and keep going.
In any event, this method is called only on very strange
markup and our best strategy is to pretend it didn't happen
and keep going.
"""
warnings.warn(msg)
def handle_startendtag(self, name, attrs):
# This is only called when the markup looks like
# <tag/>.
"""Handle an incoming empty-element tag.
This is only called when the markup looks like <tag/>.
:param name: Name of the tag.
:param attrs: Dictionary of the tag's attributes.
"""
# is_startend() tells handle_starttag not to close the tag
# just because its name matches a known empty-element tag. We
# know that this is an empty-element tag and we want to call
@ -89,6 +117,14 @@ class BeautifulSoupHTMLParser(HTMLParser):
self.handle_endtag(name)
def handle_starttag(self, name, attrs, handle_empty_element=True):
"""Handle an opening tag, e.g. '<tag>'
:param name: Name of the tag.
:param attrs: Dictionary of the tag's attributes.
:param handle_empty_element: True if this tag is known to be
an empty-element tag (i.e. there is not expected to be any
closing tag).
"""
# XXX namespace
attr_dict = {}
for key, value in attrs:
@ -96,10 +132,26 @@ class BeautifulSoupHTMLParser(HTMLParser):
# for consistency with the other tree builders.
if value is None:
value = ''
attr_dict[key] = value
if key in attr_dict:
# A single attribute shows up multiple times in this
# tag. How to handle it depends on the
# on_duplicate_attribute setting.
on_dupe = self.on_duplicate_attribute
if on_dupe == self.IGNORE:
pass
elif on_dupe in (None, self.REPLACE):
attr_dict[key] = value
else:
on_dupe(attr_dict, key, value)
else:
attr_dict[key] = value
attrvalue = '""'
#print "START", name
tag = self.soup.handle_starttag(name, None, None, attr_dict)
#print("START", name)
sourceline, sourcepos = self.getpos()
tag = self.soup.handle_starttag(
name, None, None, attr_dict, sourceline=sourceline,
sourcepos=sourcepos
)
if tag and tag.is_empty_element and handle_empty_element:
# Unlike other parsers, html.parser doesn't send separate end tag
# events for empty-element tags. (It's handled in
@ -117,20 +169,34 @@ class BeautifulSoupHTMLParser(HTMLParser):
self.already_closed_empty_element.append(name)
def handle_endtag(self, name, check_already_closed=True):
#print "END", name
"""Handle a closing tag, e.g. '</tag>'
:param name: A tag name.
:param check_already_closed: True if this tag is expected to
be the closing portion of an empty-element tag,
e.g. '<tag></tag>'.
"""
#print("END", name)
if check_already_closed and name in self.already_closed_empty_element:
# This is a redundant end tag for an empty-element tag.
# We've already called handle_endtag() for it, so just
# check it off the list.
# print "ALREADY CLOSED", name
#print("ALREADY CLOSED", name)
self.already_closed_empty_element.remove(name)
else:
self.soup.handle_endtag(name)
def handle_data(self, data):
"""Handle some textual data that shows up between tags."""
self.soup.handle_data(data)
def handle_charref(self, name):
"""Handle a numeric character reference by converting it to the
corresponding Unicode character and treating it as textual
data.
:param name: Character number, possibly in hexadecimal.
"""
# XXX workaround for a bug in HTMLParser. Remove this once
# it's fixed in all supported versions.
# http://bugs.python.org/issue13633
@ -164,6 +230,12 @@ class BeautifulSoupHTMLParser(HTMLParser):
self.handle_data(data)
def handle_entityref(self, name):
"""Handle a named entity reference by converting it to the
corresponding Unicode character(s) and treating it as textual
data.
:param name: Name of the entity reference.
"""
character = EntitySubstitution.HTML_ENTITY_TO_CHARACTER.get(name)
if character is not None:
data = character
@ -177,21 +249,29 @@ class BeautifulSoupHTMLParser(HTMLParser):
self.handle_data(data)
def handle_comment(self, data):
"""Handle an HTML comment.
:param data: The text of the comment.
"""
self.soup.endData()
self.soup.handle_data(data)
self.soup.endData(Comment)
def handle_decl(self, data):
"""Handle a DOCTYPE declaration.
:param data: The text of the declaration.
"""
self.soup.endData()
if data.startswith("DOCTYPE "):
data = data[len("DOCTYPE "):]
elif data == 'DOCTYPE':
# i.e. "<!DOCTYPE>"
data = ''
data = data[len("DOCTYPE "):]
self.soup.handle_data(data)
self.soup.endData(Doctype)
def unknown_decl(self, data):
"""Handle a declaration of unknown type -- probably a CDATA block.
:param data: The text of the declaration.
"""
if data.upper().startswith('CDATA['):
cls = CData
data = data[len('CDATA['):]
@ -202,47 +282,109 @@ class BeautifulSoupHTMLParser(HTMLParser):
self.soup.endData(cls)
def handle_pi(self, data):
"""Handle a processing instruction.
:param data: The text of the instruction.
"""
self.soup.endData()
self.soup.handle_data(data)
self.soup.endData(ProcessingInstruction)
class HTMLParserTreeBuilder(HTMLTreeBuilder):
"""A Beautiful soup `TreeBuilder` that uses the `HTMLParser` parser,
found in the Python standard library.
"""
is_xml = False
picklable = True
NAME = HTMLPARSER
features = [NAME, HTML, STRICT]
# The html.parser knows which line number and position in the
# original file is the source of an element.
TRACKS_LINE_NUMBERS = True
def __init__(self, parser_args=None, parser_kwargs=None, **kwargs):
"""Constructor.
:param parser_args: Positional arguments to pass into
the BeautifulSoupHTMLParser constructor, once it's
invoked.
:param parser_kwargs: Keyword arguments to pass into
the BeautifulSoupHTMLParser constructor, once it's
invoked.
:param kwargs: Keyword arguments for the superclass constructor.
"""
# Some keyword arguments will be pulled out of kwargs and placed
# into parser_kwargs.
extra_parser_kwargs = dict()
for arg in ('on_duplicate_attribute',):
if arg in kwargs:
value = kwargs.pop(arg)
extra_parser_kwargs[arg] = value
super(HTMLParserTreeBuilder, self).__init__(**kwargs)
parser_args = parser_args or []
parser_kwargs = parser_kwargs or {}
parser_kwargs.update(extra_parser_kwargs)
if CONSTRUCTOR_TAKES_STRICT and not CONSTRUCTOR_STRICT_IS_DEPRECATED:
parser_kwargs['strict'] = False
if CONSTRUCTOR_TAKES_CONVERT_CHARREFS:
parser_kwargs['convert_charrefs'] = False
self.parser_args = (parser_args, parser_kwargs)
def prepare_markup(self, markup, user_specified_encoding=None,
document_declared_encoding=None, exclude_encodings=None):
"""
:return: A 4-tuple (markup, original encoding, encoding
declared within markup, whether any characters had to be
replaced with REPLACEMENT CHARACTER).
"""Run any preliminary steps necessary to make incoming markup
acceptable to the parser.
:param markup: Some markup -- probably a bytestring.
:param user_specified_encoding: The user asked to try this encoding.
:param document_declared_encoding: The markup itself claims to be
in this encoding.
:param exclude_encodings: The user asked _not_ to try any of
these encodings.
:yield: A series of 4-tuples:
(markup, encoding, declared encoding,
has undergone character replacement)
Each 4-tuple represents a strategy for converting the
document to Unicode and parsing it. Each strategy will be tried
in turn.
"""
if isinstance(markup, str):
# Parse Unicode as-is.
yield (markup, None, None, False)
return
# Ask UnicodeDammit to sniff the most likely encoding.
# This was provided by the end-user; treat it as a known
# definite encoding per the algorithm laid out in the HTML5
# spec. (See the EncodingDetector class for details.)
known_definite_encodings = [user_specified_encoding]
# This was found in the document; treat it as a slightly lower-priority
# user encoding.
user_encodings = [document_declared_encoding]
try_encodings = [user_specified_encoding, document_declared_encoding]
dammit = UnicodeDammit(markup, try_encodings, is_html=True,
exclude_encodings=exclude_encodings)
dammit = UnicodeDammit(
markup,
known_definite_encodings=known_definite_encodings,
user_encodings=user_encodings,
is_html=True,
exclude_encodings=exclude_encodings
)
yield (dammit.markup, dammit.original_encoding,
dammit.declared_html_encoding,
dammit.contains_replacement_characters)
def feed(self, markup):
"""Run some incoming markup through some parsing process,
populating the `BeautifulSoup` object in self.soup.
"""
args, kwargs = self.parser_args
parser = BeautifulSoupHTMLParser(*args, **kwargs)
parser.soup = self.soup
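A sketch of the reworked UnicodeDammit call used in prepare_markup above; the keyword names come from this diff, and the exact candidate ordering lives in bs4.dammit.EncodingDetector:

from bs4.dammit import UnicodeDammit

data = "Sacr\xe9 bleu!".encode("latin-1")

dammit = UnicodeDammit(
    data,
    known_definite_encodings=["latin-1"],  # e.g. a user-specified encoding
    user_encodings=["utf-8"],              # e.g. an encoding sniffed from the document
    is_html=True,
)
print(dammit.unicode_markup)     # Sacré bleu!
print(dammit.original_encoding)  # latin-1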


@ -57,9 +57,18 @@ class LXMLTreeBuilderForXML(TreeBuilder):
DEFAULT_NSMAPS_INVERTED = _invert(DEFAULT_NSMAPS)
# NOTE: If we parsed Element objects and looked at .sourceline,
# we'd be able to see the line numbers from the original document.
# But instead we build an XMLParser or HTMLParser object to serve
# as the target of parse messages, and those messages don't include
# line numbers.
# See: https://bugs.launchpad.net/lxml/+bug/1846906
def initialize_soup(self, soup):
"""Let the BeautifulSoup object know about the standard namespace
mapping.
:param soup: A `BeautifulSoup`.
"""
super(LXMLTreeBuilderForXML, self).initialize_soup(soup)
self._register_namespaces(self.DEFAULT_NSMAPS)
@ -69,6 +78,8 @@ class LXMLTreeBuilderForXML(TreeBuilder):
while parsing the document.
This might be useful later on when creating CSS selectors.
:param mapping: A dictionary mapping namespace prefixes to URIs.
"""
for key, value in list(mapping.items()):
if key and key not in self.soup._namespaces:
@ -78,20 +89,31 @@ class LXMLTreeBuilderForXML(TreeBuilder):
self.soup._namespaces[key] = value
def default_parser(self, encoding):
# This can either return a parser object or a class, which
# will be instantiated with default arguments.
"""Find the default parser for the given encoding.
:param encoding: A string.
:return: Either a parser object or a class, which
will be instantiated with default arguments.
"""
if self._default_parser is not None:
return self._default_parser
return etree.XMLParser(
target=self, strip_cdata=False, recover=True, encoding=encoding)
def parser_for(self, encoding):
"""Instantiate an appropriate parser for the given encoding.
:param encoding: A string.
:return: A parser object such as an `etree.XMLParser`.
"""
# Use the default parser.
parser = self.default_parser(encoding)
if isinstance(parser, Callable):
# Instantiate the parser with default arguments
parser = parser(target=self, strip_cdata=False, encoding=encoding)
parser = parser(
target=self, strip_cdata=False, recover=True, encoding=encoding
)
return parser
def __init__(self, parser=None, empty_element_tags=None, **kwargs):
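Because default_parser() may return a class that parser_for() later instantiates with target, recover and encoding filled in, a parser class can be injected through the constructor's parser argument. A sketch, assuming lxml is installed and relying on BeautifulSoup passing unrecognized keyword arguments through to the builder (the constructor tests later in this commit exercise that pass-through):

from lxml import etree
from bs4 import BeautifulSoup

# Pass a parser *class*; parser_for() instantiates it with
# target=<the tree builder>, strip_cdata=False, recover=True, encoding=...
soup = BeautifulSoup(b"<root><a>text</a></root>", "xml",
                     parser=etree.XMLParser)
print(soup.a.string)  # text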
@ -116,17 +138,31 @@ class LXMLTreeBuilderForXML(TreeBuilder):
def prepare_markup(self, markup, user_specified_encoding=None,
exclude_encodings=None,
document_declared_encoding=None):
"""
:yield: A series of 4-tuples.
"""Run any preliminary steps necessary to make incoming markup
acceptable to the parser.
lxml really wants to get a bytestring and convert it to
Unicode itself. So instead of using UnicodeDammit to convert
the bytestring to Unicode using different encodings, this
implementation uses EncodingDetector to iterate over the
encodings, and tell lxml to try to parse the document as each
one in turn.
:param markup: Some markup -- hopefully a bytestring.
:param user_specified_encoding: The user asked to try this encoding.
:param document_declared_encoding: The markup itself claims to be
in this encoding.
:param exclude_encodings: The user asked _not_ to try any of
these encodings.
:yield: A series of 4-tuples:
(markup, encoding, declared encoding,
has undergone character replacement)
Each 4-tuple represents a strategy for parsing the document.
Each 4-tuple represents a strategy for converting the
document to Unicode and parsing it. Each strategy will be tried
in turn.
"""
# Instead of using UnicodeDammit to convert the bytestring to
# Unicode using different encodings, use EncodingDetector to
# iterate over the encodings, and tell lxml to try to parse
# the document as each one in turn.
is_html = not self.is_xml
if is_html:
self.processing_instruction_class = ProcessingInstruction
@ -144,9 +180,19 @@ class LXMLTreeBuilderForXML(TreeBuilder):
yield (markup.encode("utf8"), "utf8",
document_declared_encoding, False)
try_encodings = [user_specified_encoding, document_declared_encoding]
# This was provided by the end-user; treat it as a known
# definite encoding per the algorithm laid out in the HTML5
# spec. (See the EncodingDetector class for details.)
known_definite_encodings = [user_specified_encoding]
# This was found in the document; treat it as a slightly lower-priority
# user encoding.
user_encodings = [document_declared_encoding]
detector = EncodingDetector(
markup, try_encodings, is_html, exclude_encodings)
markup, known_definite_encodings=known_definite_encodings,
user_encodings=user_encodings, is_html=is_html,
exclude_encodings=exclude_encodings
)
for encoding in detector.encodings:
yield (detector.markup, encoding, document_declared_encoding, False)
@ -169,7 +215,7 @@ class LXMLTreeBuilderForXML(TreeBuilder):
self.parser.feed(data)
self.parser.close()
except (UnicodeDecodeError, LookupError, etree.ParserError) as e:
raise ParserRejectedMarkup(str(e))
raise ParserRejectedMarkup(e)
def close(self):
self.nsmaps = [self.DEFAULT_NSMAPS_INVERTED]
@ -288,7 +334,7 @@ class LXMLTreeBuilder(HTMLTreeBuilder, LXMLTreeBuilderForXML):
self.parser.feed(markup)
self.parser.close()
except (UnicodeDecodeError, LookupError, etree.ParserError) as e:
raise ParserRejectedMarkup(str(e))
raise ParserRejectedMarkup(e)
def test_fragment_to_document(self, fragment):

File diff suppressed because it is too large.


@ -20,9 +20,13 @@ import sys
import cProfile
def diagnose(data):
"""Diagnostic suite for isolating common problems."""
print("Diagnostic running on Beautiful Soup %s" % __version__)
print("Python version %s" % sys.version)
"""Diagnostic suite for isolating common problems.
:param data: A string containing markup that needs to be explained.
:return: None; diagnostics are printed to standard output.
"""
print(("Diagnostic running on Beautiful Soup %s" % __version__))
print(("Python version %s" % sys.version))
basic_parsers = ["html.parser", "html5lib", "lxml"]
for name in basic_parsers:
@ -39,65 +43,76 @@ def diagnose(data):
basic_parsers.append("lxml-xml")
try:
from lxml import etree
print("Found lxml version %s" % ".".join(map(str,etree.LXML_VERSION)))
print(("Found lxml version %s" % ".".join(map(str,etree.LXML_VERSION))))
except ImportError as e:
print (
print(
"lxml is not installed or couldn't be imported.")
if 'html5lib' in basic_parsers:
try:
import html5lib
print("Found html5lib version %s" % html5lib.__version__)
print(("Found html5lib version %s" % html5lib.__version__))
except ImportError as e:
print (
print(
"html5lib is not installed or couldn't be imported.")
if hasattr(data, 'read'):
data = data.read()
elif data.startswith("http:") or data.startswith("https:"):
print('"%s" looks like a URL. Beautiful Soup is not an HTTP client.' % data)
print(('"%s" looks like a URL. Beautiful Soup is not an HTTP client.' % data))
print("You need to use some other library to get the document behind the URL, and feed that document to Beautiful Soup.")
return
else:
try:
if os.path.exists(data):
print('"%s" looks like a filename. Reading data from the file.' % data)
print(('"%s" looks like a filename. Reading data from the file.' % data))
with open(data) as fp:
data = fp.read()
except ValueError:
# This can happen on some platforms when the 'filename' is
# too long. Assume it's data and not a filename.
pass
print()
print("")
for parser in basic_parsers:
print("Trying to parse your markup with %s" % parser)
print(("Trying to parse your markup with %s" % parser))
success = False
try:
soup = BeautifulSoup(data, features=parser)
success = True
except Exception as e:
print("%s could not parse the markup." % parser)
print(("%s could not parse the markup." % parser))
traceback.print_exc()
if success:
print("Here's what %s did with the markup:" % parser)
print(soup.prettify())
print(("Here's what %s did with the markup:" % parser))
print((soup.prettify()))
print("-" * 80)
print(("-" * 80))
def lxml_trace(data, html=True, **kwargs):
"""Print out the lxml events that occur during parsing.
This lets you see how lxml parses a document when no Beautiful
Soup code is running.
Soup code is running. You can use this to determine whether
an lxml-specific problem is in Beautiful Soup's lxml tree builders
or in lxml itself.
:param data: Some markup.
:param html: If True, markup will be parsed with lxml's HTML parser.
If False, lxml's XML parser will be used.
"""
from lxml import etree
for event, element in etree.iterparse(StringIO(data), html=html, **kwargs):
print(("%s, %4s, %s" % (event, element.tag, element.text)))
class AnnouncingParser(HTMLParser):
"""Announces HTMLParser parse events, without doing anything else."""
"""Subclass of HTMLParser that announces parse events, without doing
anything else.
You can use this to get a picture of how html.parser sees a given
document. The easiest way to do this is to call `htmlparser_trace`.
"""
def _p(self, s):
print(s)
@ -134,6 +149,8 @@ def htmlparser_trace(data):
This lets you see how HTMLParser parses a document when no
Beautiful Soup code is running.
:param data: Some markup.
"""
parser = AnnouncingParser()
parser.feed(data)
@ -154,7 +171,7 @@ def rword(length=5):
def rsentence(length=4):
"Generate a random sentence-like string."
return " ".join(rword(random.randint(4,9)) for i in list(range(length)))
return " ".join(rword(random.randint(4,9)) for i in range(length))
def rdoc(num_elements=1000):
"""Randomly generate an invalid HTML document."""
@ -176,9 +193,9 @@ def rdoc(num_elements=1000):
def benchmark_parsers(num_elements=100000):
"""Very basic head-to-head performance benchmark."""
print("Comparative parser benchmark on Beautiful Soup %s" % __version__)
print(("Comparative parser benchmark on Beautiful Soup %s" % __version__))
data = rdoc(num_elements)
print("Generated a large invalid HTML document (%d bytes)." % len(data))
print(("Generated a large invalid HTML document (%d bytes)." % len(data)))
for parser in ["lxml", ["lxml", "html"], "html5lib", "html.parser"]:
success = False
@ -188,26 +205,26 @@ def benchmark_parsers(num_elements=100000):
b = time.time()
success = True
except Exception as e:
print("%s could not parse the markup." % parser)
print(("%s could not parse the markup." % parser))
traceback.print_exc()
if success:
print("BS4+%s parsed the markup in %.2fs." % (parser, b-a))
print(("BS4+%s parsed the markup in %.2fs." % (parser, b-a)))
from lxml import etree
a = time.time()
etree.HTML(data)
b = time.time()
print("Raw lxml parsed the markup in %.2fs." % (b-a))
print(("Raw lxml parsed the markup in %.2fs." % (b-a)))
import html5lib
parser = html5lib.HTMLParser()
a = time.time()
parser.parse(data)
b = time.time()
print("Raw html5lib parsed the markup in %.2fs." % (b-a))
print(("Raw html5lib parsed the markup in %.2fs." % (b-a)))
def profile(num_elements=100000, parser="lxml"):
"""Use Python's profiler on a randomly generated document."""
filehandle = tempfile.NamedTemporaryFile()
filename = filehandle.name
@ -220,5 +237,6 @@ def profile(num_elements=100000, parser="lxml"):
stats.sort_stats("cumulative")
stats.print_stats('_html5lib|bs4', 50)
# If this file is run as a script, standard input is diagnosed.
if __name__ == '__main__':
diagnose(sys.stdin.read())
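The module's entry points can also be called directly; a minimal usage sketch:

from bs4.diagnose import diagnose, htmlparser_trace

markup = "<p>Unbalanced <b>markup</p>"
diagnose(markup)          # reports what each installed parser makes of the markup
htmlparser_trace(markup)  # raw html.parser events, with no Beautiful Soup involved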

File diff suppressed because it is too large.


@ -5,6 +5,28 @@ class Formatter(EntitySubstitution):
Some parts of this strategy come from the distinction between
HTML4, HTML5, and XML. Others are configurable by the user.
Formatters are passed in as the `formatter` argument to methods
like `PageElement.encode`. Most people won't need to think about
formatters, and most people who need to think about them can pass
in one of these predefined strings as `formatter` rather than
making a new Formatter object:
For HTML documents:
* 'html' - HTML entity substitution for generic HTML documents. (default)
* 'html5' - HTML entity substitution for HTML5 documents, as
well as some optimizations in the way tags are rendered.
* 'minimal' - Only make the substitutions necessary to guarantee
valid HTML.
* None - Do not perform any substitution. This will be faster
but may result in invalid markup.
For XML documents:
* 'html' - Entity substitution for XHTML documents.
* 'minimal' - Only make the substitutions necessary to guarantee
valid XML. (default)
* None - Do not perform any substitution. This will be faster
but may result in invalid markup.
"""
# Registries of XML and HTML formatters.
XML_FORMATTERS = {}
@ -27,11 +49,26 @@ class Formatter(EntitySubstitution):
def __init__(
self, language=None, entity_substitution=None,
void_element_close_prefix='/', cdata_containing_tags=None,
empty_attributes_are_booleans=False,
):
"""
"""Constructor.
:param void_element_close_prefix: By default, represent void
elements as <tag/> rather than <tag>
:param language: This should be Formatter.XML if you are formatting
XML markup and Formatter.HTML if you are formatting HTML markup.
:param entity_substitution: A function to call to replace special
characters with XML/HTML entities. For examples, see
bs4.dammit.EntitySubstitution.substitute_html and substitute_xml.
:param void_element_close_prefix: By default, void elements
are represented as <tag/> (XML rules) rather than <tag>
(HTML rules). To get <tag>, pass in the empty string.
:param cdata_containing_tags: The list of tags that are defined
as containing CDATA in this dialect. For example, in HTML,
<script> and <style> tags are defined as containing CDATA,
and their contents should not be formatted.
        :param empty_attributes_are_booleans: Render attributes whose value
is the empty string as HTML-style boolean attributes.
(Attributes whose value is None are always rendered this way.)
"""
self.language = language
self.entity_substitution = entity_substitution
@ -39,9 +76,17 @@ class Formatter(EntitySubstitution):
self.cdata_containing_tags = self._default(
language, cdata_containing_tags, 'cdata_containing_tags'
)
self.empty_attributes_are_booleans=empty_attributes_are_booleans
def substitute(self, ns):
"""Process a string that needs to undergo entity substitution."""
"""Process a string that needs to undergo entity substitution.
This may be a string encountered in an attribute value or as
text.
:param ns: A string.
:return: A string with certain characters replaced by named
or numeric entities.
"""
if not self.entity_substitution:
return ns
from .element import NavigableString
@ -54,21 +99,41 @@ class Formatter(EntitySubstitution):
return self.entity_substitution(ns)
def attribute_value(self, value):
"""Process the value of an attribute."""
"""Process the value of an attribute.
        :param value: A string.
:return: A string with certain characters replaced by named
or numeric entities.
"""
return self.substitute(value)
def attributes(self, tag):
"""Reorder a tag's attributes however you want."""
return sorted(tag.attrs.items())
"""Reorder a tag's attributes however you want.
By default, attributes are sorted alphabetically. This makes
behavior consistent between Python 2 and Python 3, and preserves
backwards compatibility with older versions of Beautiful Soup.
        If `empty_attributes_are_booleans` is True, then attributes whose
values are set to the empty string will be treated as boolean
attributes.
"""
if tag.attrs is None:
return []
return sorted(
(k, (None if self.empty_attributes_are_booleans and v == '' else v))
for k, v in list(tag.attrs.items())
)
class HTMLFormatter(Formatter):
"""A generic Formatter for HTML."""
REGISTRY = {}
def __init__(self, *args, **kwargs):
return super(HTMLFormatter, self).__init__(self.HTML, *args, **kwargs)
class XMLFormatter(Formatter):
"""A generic Formatter for XML."""
REGISTRY = {}
def __init__(self, *args, **kwargs):
return super(XMLFormatter, self).__init__(self.XML, *args, **kwargs)
@ -80,7 +145,8 @@ HTMLFormatter.REGISTRY['html'] = HTMLFormatter(
)
HTMLFormatter.REGISTRY["html5"] = HTMLFormatter(
entity_substitution=EntitySubstitution.substitute_html,
void_element_close_prefix = None
void_element_close_prefix=None,
empty_attributes_are_booleans=True,
)
HTMLFormatter.REGISTRY["minimal"] = HTMLFormatter(
entity_substitution=EntitySubstitution.substitute_xml
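A sketch of the registered formatters in action, including the new empty_attributes_are_booleans behavior of 'html5':

from bs4 import BeautifulSoup

soup = BeautifulSoup('<option selected="">caf\xe9</option>', 'html.parser')

print(soup.decode(formatter="html"))     # <option selected="">caf&eacute;</option>
print(soup.decode(formatter="html5"))    # <option selected>caf&eacute;</option>
print(soup.decode(formatter="minimal"))  # <option selected="">café</option>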


@ -8,6 +8,7 @@ import pickle
import copy
import functools
import unittest
import warnings
from unittest import TestCase
from bs4 import BeautifulSoup
from bs4.element import (
@ -15,7 +16,10 @@ from bs4.element import (
Comment,
ContentMetaAttributeValue,
Doctype,
PYTHON_SPECIFIC_ENCODINGS,
SoupStrainer,
Script,
Stylesheet,
Tag
)
@ -83,8 +87,22 @@ class SoupTest(unittest.TestCase):
if compare_parsed_to is None:
compare_parsed_to = to_parse
# Verify that the documents come out the same.
self.assertEqual(obj.decode(), self.document_for(compare_parsed_to))
# Also run some checks on the BeautifulSoup object itself:
# Verify that every tag that was opened was eventually closed.
# There are no tags in the open tag counter.
assert all(v==0 for v in list(obj.open_tag_counter.values()))
# The only tag in the tag stack is the one for the root
# document.
self.assertEqual(
[obj.ROOT_TAG_NAME], [x.name for x in obj.tagStack]
)
def assertConnectedness(self, element):
"""Ensure that next_element and previous_element are properly
set for all descendants of the given element.
@ -211,7 +229,41 @@ class SoupTest(unittest.TestCase):
return child
class HTMLTreeBuilderSmokeTest(object):
class TreeBuilderSmokeTest(object):
# Tests that are common to HTML and XML tree builders.
def test_fuzzed_input(self):
# This test centralizes in one place the various fuzz tests
# for Beautiful Soup created by the oss-fuzz project.
# These strings superficially resemble markup, but they
# generally can't be parsed into anything. The best we can
# hope for is that parsing these strings won't crash the
# parser.
#
# n.b. This markup is commented out because these fuzz tests
# _do_ crash the parser. However the crashes are due to bugs
# in html.parser, not Beautiful Soup -- otherwise I'd fix the
# bugs!
bad_markup = [
# https://bugs.chromium.org/p/oss-fuzz/issues/detail?id=28873
# https://github.com/guidovranken/python-library-fuzzers/blob/master/corp-html/519e5b4269a01185a0d5e76295251921da2f0700
# https://bugs.python.org/issue37747
#
#b'\n<![\xff\xfe\xfe\xcd\x00',
#https://github.com/guidovranken/python-library-fuzzers/blob/master/corp-html/de32aa55785be29bbc72a1a8e06b00611fb3d9f8
# https://bugs.python.org/issue34480
#
#b'<![n\x00'
]
for markup in bad_markup:
with warnings.catch_warnings(record=False):
soup = self.soup(markup)
class HTMLTreeBuilderSmokeTest(TreeBuilderSmokeTest):
"""A basic test of a treebuilder's competence.
@ -233,6 +285,22 @@ class HTMLTreeBuilderSmokeTest(object):
new_tag = soup.new_tag(name)
self.assertEqual(True, new_tag.is_empty_element)
def test_special_string_containers(self):
soup = self.soup(
"<style>Some CSS</style><script>Some Javascript</script>"
)
assert isinstance(soup.style.string, Stylesheet)
assert isinstance(soup.script.string, Script)
soup = self.soup(
"<style><!--Some CSS--></style>"
)
assert isinstance(soup.style.string, Stylesheet)
# The contents of the style tag resemble an HTML comment, but
# it's not treated as a comment.
self.assertEqual("<!--Some CSS-->", soup.style.string)
assert isinstance(soup.style.string, Stylesheet)
def test_pickle_and_unpickle_identity(self):
# Pickling a tree, then unpickling it, yields a tree identical
# to the original.
@ -250,18 +318,21 @@ class HTMLTreeBuilderSmokeTest(object):
doctype = soup.contents[0]
self.assertEqual(doctype.__class__, Doctype)
self.assertEqual(doctype, doctype_fragment)
self.assertEqual(str(soup)[:len(doctype_str)], doctype_str)
self.assertEqual(
soup.encode("utf8")[:len(doctype_str)],
doctype_str
)
# Make sure that the doctype was correctly associated with the
# parse tree and that the rest of the document parsed.
self.assertEqual(soup.p.contents[0], 'foo')
def _document_with_doctype(self, doctype_fragment):
def _document_with_doctype(self, doctype_fragment, doctype_string="DOCTYPE"):
"""Generate and parse a document with the given doctype."""
doctype = '<!DOCTYPE %s>' % doctype_fragment
doctype = '<!%s %s>' % (doctype_string, doctype_fragment)
markup = doctype + '\n<p>foo</p>'
soup = self.soup(markup)
return doctype, soup
return doctype.encode("utf8"), soup
def test_normal_doctypes(self):
"""Make sure normal, everyday HTML doctypes are handled correctly."""
@ -274,6 +345,27 @@ class HTMLTreeBuilderSmokeTest(object):
doctype = soup.contents[0]
self.assertEqual("", doctype.strip())
def test_mixed_case_doctype(self):
# A lowercase or mixed-case doctype becomes a Doctype.
for doctype_fragment in ("doctype", "DocType"):
doctype_str, soup = self._document_with_doctype(
"html", doctype_fragment
)
# Make sure a Doctype object was created and that the DOCTYPE
# is uppercase.
doctype = soup.contents[0]
self.assertEqual(doctype.__class__, Doctype)
self.assertEqual(doctype, "html")
self.assertEqual(
soup.encode("utf8")[:len(doctype_str)],
b"<!DOCTYPE html>"
)
# Make sure that the doctype was correctly associated with the
# parse tree and that the rest of the document parsed.
self.assertEqual(soup.p.contents[0], 'foo')
def test_public_doctype_with_url(self):
doctype = 'html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"'
self.assertDoctypeHandled(doctype)
@ -532,7 +624,7 @@ Hello, world!
self.assertSoupEquals("&#10000000000000;", expect)
self.assertSoupEquals("&#x10000000000000;", expect)
self.assertSoupEquals("&#1000000000;", expect)
def test_multipart_strings(self):
"Mostly to prevent a recurrence of a bug in the html5lib treebuilder."
soup = self.soup("<html><h2>\nfoo</h2><p></p></html>")
@ -594,7 +686,7 @@ Hello, world!
markup = b'<a class="foo bar">'
soup = self.soup(markup)
self.assertEqual(['foo', 'bar'], soup.a['class'])
#
# Generally speaking, tests below this point are more tests of
# Beautiful Soup than tests of the tree builders. But parsers are
@ -779,11 +871,44 @@ Hello, world!
# encoding.
self.assertEqual('utf8', charset.encode("utf8"))
def test_python_specific_encodings_not_used_in_charset(self):
# You can encode an HTML document using a Python-specific
# encoding, but that encoding won't be mentioned _inside_ the
# resulting document. Instead, the document will appear to
# have no encoding.
for markup in [
            b'<meta charset="utf8"></head>',
            b'<meta id="encoding" charset="utf-8" />',
]:
soup = self.soup(markup)
for encoding in PYTHON_SPECIFIC_ENCODINGS:
if encoding in (
'idna', 'mbcs', 'oem', 'undefined',
'string_escape', 'string-escape'
):
# For one reason or another, these will raise an
# exception if we actually try to use them, so don't
# bother.
continue
encoded = soup.encode(encoding)
assert b'meta charset=""' in encoded
assert encoding.encode("ascii") not in encoded
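A sketch of the behavior this test pins down, assuming unicode_escape is among the PYTHON_SPECIFIC_ENCODINGS:

from bs4 import BeautifulSoup

soup = BeautifulSoup('<meta charset="utf-8"/>', 'html.parser')
# A Python-internal codec can encode the document, but it isn't a real
# document encoding, so the charset attribute comes out blank.
print(soup.encode("unicode_escape"))  # b'<meta charset=""/>'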
def test_tag_with_no_attributes_can_have_attributes_added(self):
data = self.soup("<a>text</a>")
data.a['foo'] = 'bar'
self.assertEqual('<a foo="bar">text</a>', data.a.decode())
def test_closing_tag_with_no_opening_tag(self):
# Without BeautifulSoup.open_tag_counter, the </span> tag will
# cause _popToTag to be called over and over again as we look
# for a <span> tag that wasn't there. The result is that 'text2'
# will show up outside the body of the document.
soup = self.soup("<body><div><p>text1</p></span>text2</div></body>")
self.assertEqual(
"<body><div><p>text1</p>text2</div></body>", soup.body.decode()
)
def test_worst_case(self):
"""Test the worst case (currently) for linking issues."""
@ -791,7 +916,7 @@ Hello, world!
self.linkage_validator(soup)
class XMLTreeBuilderSmokeTest(object):
class XMLTreeBuilderSmokeTest(TreeBuilderSmokeTest):
def test_pickle_and_unpickle_identity(self):
# Pickling a tree, then unpickling it, yields a tree identical
@ -812,6 +937,25 @@ class XMLTreeBuilderSmokeTest(object):
soup = self.soup(markup)
self.assertEqual(markup, soup.encode("utf8"))
def test_python_specific_encodings_not_used_in_xml_declaration(self):
# You can encode an XML document using a Python-specific
# encoding, but that encoding won't be mentioned _inside_ the
# resulting document.
markup = b"""<?xml version="1.0"?>\n<foo/>"""
soup = self.soup(markup)
for encoding in PYTHON_SPECIFIC_ENCODINGS:
if encoding in (
'idna', 'mbcs', 'oem', 'undefined',
'string_escape', 'string-escape'
):
# For one reason or another, these will raise an
# exception if we actually try to use them, so don't
# bother.
continue
encoded = soup.encode(encoding)
assert b'<?xml version="1.0"?>' in encoded
assert encoding.encode("ascii") not in encoded
def test_processing_instruction(self):
markup = b"""<?xml version="1.0" encoding="utf8"?>\n<?PITarget PIContent?>"""
soup = self.soup(markup)
@ -828,7 +972,7 @@ class XMLTreeBuilderSmokeTest(object):
soup = self.soup(markup)
self.assertEqual(
soup.encode("utf-8"), markup)
def test_nested_namespaces(self):
doc = b"""<?xml version="1.0" encoding="utf-8"?>
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.1//EN" "http://www.w3.org/TR/xhtml11/DTD/xhtml11.dtd">


@ -168,3 +168,59 @@ class HTML5LibBuilderSmokeTest(SoupTest, HTML5TreeBuilderSmokeTest):
for form in soup.find_all('form'):
inputs.extend(form.find_all('input'))
self.assertEqual(len(inputs), 1)
def test_tracking_line_numbers(self):
        # The html5lib TreeBuilder keeps track of line number and
        # position of each element.
        markup = "\n   <p>\n\n<sourceline>\n<b>text</b></sourceline><sourcepos></p>"
soup = self.soup(markup)
self.assertEqual(2, soup.p.sourceline)
self.assertEqual(5, soup.p.sourcepos)
self.assertEqual("sourceline", soup.p.find('sourceline').name)
# You can deactivate this behavior.
soup = self.soup(markup, store_line_numbers=False)
self.assertEqual("sourceline", soup.p.sourceline.name)
self.assertEqual("sourcepos", soup.p.sourcepos.name)
def test_special_string_containers(self):
# The html5lib tree builder doesn't support this standard feature,
# because there's no way of knowing, when a string is created,
# where in the tree it will eventually end up.
pass
def test_html5_attributes(self):
# The html5lib TreeBuilder can convert any entity named in
# the HTML5 spec to a sequence of Unicode characters, and
# convert those Unicode characters to a (potentially
# different) named entity on the way out.
#
# This is a copy of the same test from
# HTMLParserTreeBuilderSmokeTest. It's not in the superclass
# because the lxml HTML TreeBuilder _doesn't_ work this way.
for input_element, output_unicode, output_element in (
("&RightArrowLeftArrow;", '\u21c4', b'&rlarr;'),
('&models;', '\u22a7', b'&models;'),
('&Nfr;', '\U0001d511', b'&Nfr;'),
('&ngeqq;', '\u2267\u0338', b'&ngeqq;'),
('&not;', '\xac', b'&not;'),
('&Not;', '\u2aec', b'&Not;'),
('&quot;', '"', b'"'),
('&there4;', '\u2234', b'&there4;'),
('&Therefore;', '\u2234', b'&there4;'),
('&therefore;', '\u2234', b'&there4;'),
("&fjlig;", 'fj', b'fj'),
("&sqcup;", '\u2294', b'&sqcup;'),
("&sqcups;", '\u2294\ufe00', b'&sqcups;'),
("&apos;", "'", b"'"),
("&verbar;", "|", b"|"),
):
markup = '<div>%s</div>' % input_element
div = self.soup(markup).div
without_element = div.encode()
expect = b"<div>%s</div>" % output_unicode.encode("utf8")
self.assertEqual(without_element, expect)
with_element = div.encode(formatter="html")
expect = b"<div>%s</div>" % output_element
self.assertEqual(with_element, expect)


@ -3,6 +3,7 @@ trees."""
from pdb import set_trace
import pickle
import warnings
from bs4.testing import SoupTest, HTMLTreeBuilderSmokeTest
from bs4.builder import HTMLParserTreeBuilder
from bs4.builder._htmlparser import BeautifulSoupHTMLParser
@ -37,6 +38,88 @@ class HTMLParserTreeBuilderSmokeTest(SoupTest, HTMLTreeBuilderSmokeTest):
# finishes working is handled.
self.assertSoupEquals("foo &# bar", "foo &amp;# bar")
def test_tracking_line_numbers(self):
# The html.parser TreeBuilder keeps track of line number and
# position of each element.
markup = "\n <p>\n\n<sourceline>\n<b>text</b></sourceline><sourcepos></p>"
soup = self.soup(markup)
self.assertEqual(2, soup.p.sourceline)
self.assertEqual(3, soup.p.sourcepos)
self.assertEqual("sourceline", soup.p.find('sourceline').name)
# You can deactivate this behavior.
soup = self.soup(markup, store_line_numbers=False)
self.assertEqual("sourceline", soup.p.sourceline.name)
self.assertEqual("sourcepos", soup.p.sourcepos.name)
def test_on_duplicate_attribute(self):
# The html.parser tree builder has a variety of ways of
# handling a tag that contains the same attribute multiple times.
markup = '<a class="cls" href="url1" href="url2" href="url3" id="id">'
# If you don't provide any particular value for
# on_duplicate_attribute, later values replace earlier values.
soup = self.soup(markup)
self.assertEqual("url3", soup.a['href'])
self.assertEqual(["cls"], soup.a['class'])
self.assertEqual("id", soup.a['id'])
# You can also get this behavior explicitly.
def assert_attribute(on_duplicate_attribute, expected):
soup = self.soup(
markup, on_duplicate_attribute=on_duplicate_attribute
)
self.assertEqual(expected, soup.a['href'])
# Verify that non-duplicate attributes are treated normally.
self.assertEqual(["cls"], soup.a['class'])
self.assertEqual("id", soup.a['id'])
assert_attribute(None, "url3")
assert_attribute(BeautifulSoupHTMLParser.REPLACE, "url3")
# You can ignore subsequent values in favor of the first.
assert_attribute(BeautifulSoupHTMLParser.IGNORE, "url1")
# And you can pass in a callable that does whatever you want.
def accumulate(attrs, key, value):
if not isinstance(attrs[key], list):
attrs[key] = [attrs[key]]
attrs[key].append(value)
assert_attribute(accumulate, ["url1", "url2", "url3"])
def test_html5_attributes(self):
# The html.parser TreeBuilder can convert any entity named in
# the HTML5 spec to a sequence of Unicode characters, and
# convert those Unicode characters to a (potentially
# different) named entity on the way out.
for input_element, output_unicode, output_element in (
("&RightArrowLeftArrow;", '\u21c4', b'&rlarr;'),
('&models;', '\u22a7', b'&models;'),
('&Nfr;', '\U0001d511', b'&Nfr;'),
('&ngeqq;', '\u2267\u0338', b'&ngeqq;'),
('&not;', '\xac', b'&not;'),
('&Not;', '\u2aec', b'&Not;'),
('&quot;', '"', b'"'),
('&there4;', '\u2234', b'&there4;'),
('&Therefore;', '\u2234', b'&there4;'),
('&therefore;', '\u2234', b'&there4;'),
("&fjlig;", 'fj', b'fj'),
("&sqcup;", '\u2294', b'&sqcup;'),
("&sqcups;", '\u2294\ufe00', b'&sqcups;'),
("&apos;", "'", b"'"),
("&verbar;", "|", b"|"),
):
markup = '<div>%s</div>' % input_element
div = self.soup(markup).div
without_element = div.encode()
expect = b"<div>%s</div>" % output_unicode.encode("utf8")
self.assertEqual(without_element, expect)
with_element = div.encode(formatter="html")
expect = b"<div>%s</div>" % output_element
self.assertEqual(with_element, expect)
class TestHTMLParserSubclass(SoupTest):
def test_error(self):
@ -44,4 +127,8 @@ class TestHTMLParserSubclass(SoupTest):
that doesn't cause a crash.
"""
parser = BeautifulSoupHTMLParser()
parser.error("don't crash")
with warnings.catch_warnings(record=True) as warns:
parser.error("don't crash")
[warning] = warns
assert "don't crash" == str(warning.message)


@ -45,7 +45,7 @@ class LXMLTreeBuilderSmokeTest(SoupTest, HTMLTreeBuilderSmokeTest):
"<p>foo&#x10000000000000;bar</p>", "<p>foobar</p>")
self.assertSoupEquals(
"<p>foo&#1000000000;bar</p>", "<p>foobar</p>")
def test_entities_in_foreign_document_encoding(self):
# We can't implement this case correctly because by the time we
# hear about markup like "&#147;", it's been (incorrectly) converted into
@ -71,6 +71,21 @@ class LXMLTreeBuilderSmokeTest(SoupTest, HTMLTreeBuilderSmokeTest):
self.assertEqual("<b/>", str(soup.b))
self.assertTrue("BeautifulStoneSoup class is deprecated" in str(w[0].message))
def test_tracking_line_numbers(self):
# The lxml TreeBuilder cannot keep track of line numbers from
# the original markup. Even if you ask for line numbers, we
# don't have 'em.
#
# This means that if you have a tag like <sourceline> or
# <sourcepos>, attribute access will find it rather than
# giving you a numeric answer.
soup = self.soup(
"\n <p>\n\n<sourceline>\n<b>text</b></sourceline><sourcepos></p>",
store_line_numbers=True
)
self.assertEqual("sourceline", soup.p.sourceline.name)
self.assertEqual("sourcepos", soup.p.sourcepos.name)
@skipIf(
not LXML_PRESENT,
"lxml seems not to be present, not testing its XML tree builder.")


@ -3,6 +3,7 @@
from pdb import set_trace
import logging
import os
import unittest
import sys
import tempfile
@ -10,18 +11,27 @@ import tempfile
from bs4 import (
BeautifulSoup,
BeautifulStoneSoup,
GuessedAtParserWarning,
MarkupResemblesLocatorWarning,
)
from bs4.builder import (
TreeBuilder,
ParserRejectedMarkup,
)
from bs4.element import (
CharsetMetaAttributeValue,
Comment,
ContentMetaAttributeValue,
SoupStrainer,
NamespacedAttribute,
Tag,
NavigableString,
)
import bs4.dammit
from bs4.dammit import (
EntitySubstitution,
UnicodeDammit,
EncodingDetector,
)
from bs4.testing import (
default_builder,
@ -62,10 +72,21 @@ class TestConstructor(SoupTest):
def __init__(self, **kwargs):
self.called_with = kwargs
self.is_xml = True
self.store_line_numbers = False
self.cdata_list_attributes = []
self.preserve_whitespace_tags = []
self.string_containers = {}
def initialize_soup(self, soup):
pass
def feed(self, markup):
self.fed = markup
def reset(self):
pass
def ignore(self, ignore):
pass
set_up_substitutions = can_be_empty_element = ignore
def prepare_markup(self, *args, **kwargs):
return ''
yield "prepared markup", "original encoding", "declared encoding", "contains replacement characters"
kwargs = dict(
var="value",
@ -77,7 +98,8 @@ class TestConstructor(SoupTest):
soup = BeautifulSoup('', builder=Mock, **kwargs)
assert isinstance(soup.builder, Mock)
self.assertEqual(dict(var="value"), soup.builder.called_with)
self.assertEqual("prepared markup", soup.builder.fed)
# You can also instantiate the TreeBuilder yourself. In this
# case, that specific object is used and any keyword arguments
# to the BeautifulSoup constructor are ignored.
@ -91,6 +113,26 @@ class TestConstructor(SoupTest):
self.assertEqual(builder, soup.builder)
self.assertEqual(kwargs, builder.called_with)
def test_parser_markup_rejection(self):
# If markup is completely rejected by the parser, an
# explanatory ParserRejectedMarkup exception is raised.
class Mock(TreeBuilder):
def feed(self, *args, **kwargs):
raise ParserRejectedMarkup("Nope.")
def prepare_markup(self, *args, **kwargs):
# We're going to try two different ways of preparing this markup,
# but feed() will reject both of them.
yield markup, None, None, False
yield markup, None, None, False
import re
self.assertRaisesRegex(
ParserRejectedMarkup,
"The markup you provided was rejected by the parser. Trying a different parser or a different encoding may help.",
BeautifulSoup, '', builder=Mock,
)
def test_cdata_list_attributes(self):
# Most attribute values are represented as scalars, but the
# HTML standard says that some attributes, like 'class' have
@ -120,28 +162,96 @@ class TestConstructor(SoupTest):
self.assertEqual(["an", "id"], a['id'])
self.assertEqual(" a class ", a['class'])
def test_replacement_classes(self):
# Test the ability to pass in replacements for element classes
# which will be used when building the tree.
class TagPlus(Tag):
pass
class StringPlus(NavigableString):
pass
class CommentPlus(Comment):
pass
soup = self.soup(
"<a><b>foo</b>bar</a><!--whee-->",
element_classes = {
Tag: TagPlus,
NavigableString: StringPlus,
Comment: CommentPlus,
}
)
# The tree was built with TagPlus, StringPlus, and CommentPlus objects,
# rather than Tag, String, and Comment objects.
assert all(
isinstance(x, (TagPlus, StringPlus, CommentPlus))
for x in soup.recursiveChildGenerator()
)
def test_alternate_string_containers(self):
# Test the ability to customize the string containers for
# different types of tags.
class PString(NavigableString):
pass
class BString(NavigableString):
pass
soup = self.soup(
"<div>Hello.<p>Here is <b>some <i>bolded</i></b> text",
string_containers = {
'b': BString,
'p': PString,
}
)
# The string before the <p> tag is a regular NavigableString.
assert isinstance(soup.div.contents[0], NavigableString)
# The string inside the <p> tag, but not inside the <i> tag,
# is a PString.
assert isinstance(soup.p.contents[0], PString)
# Every string inside the <b> tag is a BString, even the one that
# was also inside an <i> tag.
for s in soup.b.strings:
assert isinstance(s, BString)
# Now that parsing was complete, the string_container_stack
# (where this information was kept) has been cleared out.
self.assertEqual([], soup.string_container_stack)
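A compact standalone sketch of the same customization; EmphasizedString is a hypothetical NavigableString subclass:

from bs4 import BeautifulSoup
from bs4.element import NavigableString

class EmphasizedString(NavigableString):  # hypothetical subclass
    pass

soup = BeautifulSoup("<p>plain <em>shouted</em></p>", "html.parser",
                     string_containers={"em": EmphasizedString})
print(type(soup.em.string).__name__)      # EmphasizedString
print(type(soup.p.contents[0]).__name__)  # NavigableString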
class TestWarnings(SoupTest):
def _no_parser_specified(self, s, is_there=True):
v = s.startswith(BeautifulSoup.NO_PARSER_SPECIFIED_WARNING[:80])
self.assertTrue(v)
def _assert_warning(self, warnings, cls):
for w in warnings:
if isinstance(w.message, cls):
return w
        raise Exception("%s warning not found in %r" % (cls, warnings))
def _assert_no_parser_specified(self, w):
warning = self._assert_warning(w, GuessedAtParserWarning)
message = str(warning.message)
self.assertTrue(
message.startswith(BeautifulSoup.NO_PARSER_SPECIFIED_WARNING[:60])
)
def test_warning_if_no_parser_specified(self):
with warnings.catch_warnings(record=True) as w:
soup = self.soup("<a><b></b></a>")
msg = str(w[0].message)
self._assert_no_parser_specified(msg)
soup = BeautifulSoup("<a><b></b></a>")
self._assert_no_parser_specified(w)
def test_warning_if_parser_specified_too_vague(self):
with warnings.catch_warnings(record=True) as w:
soup = self.soup("<a><b></b></a>", "html")
msg = str(w[0].message)
self._assert_no_parser_specified(msg)
soup = BeautifulSoup("<a><b></b></a>", "html")
self._assert_no_parser_specified(w)
def test_no_warning_if_explicit_parser_specified(self):
with warnings.catch_warnings(record=True) as w:
soup = self.soup("<a><b></b></a>", "html.parser")
soup = BeautifulSoup("<a><b></b></a>", "html.parser")
self.assertEqual([], w)
def test_parseOnlyThese_renamed_to_parse_only(self):
@ -165,41 +275,58 @@ class TestWarnings(SoupTest):
self.assertRaises(
TypeError, self.soup, "<a>", no_such_argument=True)
class TestWarnings(SoupTest):
def test_disk_file_warning(self):
filehandle = tempfile.NamedTemporaryFile()
filename = filehandle.name
try:
with warnings.catch_warnings(record=True) as w:
soup = self.soup(filename)
msg = str(w[0].message)
self.assertTrue("looks like a filename" in msg)
warning = self._assert_warning(w, MarkupResemblesLocatorWarning)
self.assertTrue("looks like a filename" in str(warning.message))
finally:
filehandle.close()
# The file no longer exists, so Beautiful Soup will no longer issue the warning.
with warnings.catch_warnings(record=True) as w:
soup = self.soup(filename)
self.assertEqual(0, len(w))
self.assertEqual([], w)
def test_directory_warning(self):
try:
filename = tempfile.mkdtemp()
with warnings.catch_warnings(record=True) as w:
soup = self.soup(filename)
warning = self._assert_warning(w, MarkupResemblesLocatorWarning)
self.assertTrue("looks like a directory" in str(warning.message))
finally:
os.rmdir(filename)
# The directory no longer exists, so Beautiful Soup will no longer issue the warning.
with warnings.catch_warnings(record=True) as w:
soup = self.soup(filename)
self.assertEqual([], w)
def test_url_warning_with_bytes_url(self):
with warnings.catch_warnings(record=True) as warning_list:
soup = self.soup(b"http://www.crummybytes.com/")
# Be aware this isn't the only warning that can be raised during
# execution..
self.assertTrue(any("looks like a URL" in str(w.message)
for w in warning_list))
warning = self._assert_warning(
warning_list, MarkupResemblesLocatorWarning
)
self.assertTrue("looks like a URL" in str(warning.message))
def test_url_warning_with_unicode_url(self):
with warnings.catch_warnings(record=True) as warning_list:
# note - this url must differ from the bytes one otherwise
# python's warnings system swallows the second warning
soup = self.soup("http://www.crummyunicode.com/")
self.assertTrue(any("looks like a URL" in str(w.message)
for w in warning_list))
warning = self._assert_warning(
warning_list, MarkupResemblesLocatorWarning
)
self.assertTrue("looks like a URL" in str(warning.message))
def test_url_warning_with_bytes_and_space(self):
# Here the markup contains something besides a URL, so no warning
# is issued.
with warnings.catch_warnings(record=True) as warning_list:
soup = self.soup(b"http://www.crummybytes.com/ is great")
self.assertFalse(any("looks like a URL" in str(w.message)
@ -241,6 +368,51 @@ class TestEntitySubstitution(unittest.TestCase):
self.assertEqual(self.sub.substitute_html(dammit.markup),
"&lsquo;&rsquo;foo&ldquo;&rdquo;")
def test_html5_entity(self):
# Some HTML5 entities correspond to single- or multi-character
# Unicode sequences.
for entity, u in (
# A few spot checks of our ability to recognize
# special character sequences and convert them
# to named entities.
('&models;', '\u22a7'),
('&Nfr;', '\U0001d511'),
('&ngeqq;', '\u2267\u0338'),
('&not;', '\xac'),
('&Not;', '\u2aec'),
            # We _could_ convert | to &verbar;, but we don't, because
            # | is an ASCII character.
            ('|', '|'),
# Similarly for the fj ligature, which we could convert to
# &fjlig;, but we don't.
("fj", "fj"),
# We do convert _these_ ASCII characters to HTML entities,
# because that's required to generate valid HTML.
('&gt;', '>'),
('&lt;', '<'),
('&amp;', '&'),
):
template = '3 %s 4'
raw = template % u
with_entities = template % entity
self.assertEqual(self.sub.substitute_html(raw), with_entities)
def test_html5_entity_with_variation_selector(self):
# Some HTML5 entities correspond either to a single-character
# Unicode sequence _or_ to the same character plus U+FE00,
# VARIATION SELECTOR 1. We can handle this.
data = "fjords \u2294 penguins"
markup = "fjords &sqcup; penguins"
self.assertEqual(self.sub.substitute_html(data), markup)
data = "fjords \u2294\ufe00 penguins"
markup = "fjords &sqcups; penguins"
self.assertEqual(self.sub.substitute_html(data), markup)
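substitute_html is usable directly as a class method; a quick sketch based on the cases above:

from bs4.dammit import EntitySubstitution

print(EntitySubstitution.substitute_html("\u2234 caf\xe9 \u2294\ufe00"))
# &there4; caf&eacute; &sqcups;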
def test_xml_converstion_includes_no_quotes_if_make_quoted_attribute_is_false(self):
s = 'Welcome to "my bar"'
self.assertEqual(self.sub.substitute_xml(s, False), s)
@ -350,186 +522,26 @@ class TestEncodingConversion(SoupTest):
markup = '<div><a \N{SNOWMAN}="snowman"></a></div>'
self.assertEqual(self.soup(markup).div.encode("utf8"), markup.encode("utf8"))
class TestUnicodeDammit(unittest.TestCase):
"""Standalone tests of UnicodeDammit."""
def test_unicode_input(self):
markup = "I'm already Unicode! \N{SNOWMAN}"
dammit = UnicodeDammit(markup)
self.assertEqual(dammit.unicode_markup, markup)
def test_smart_quotes_to_unicode(self):
markup = b"<foo>\x91\x92\x93\x94</foo>"
dammit = UnicodeDammit(markup)
self.assertEqual(
dammit.unicode_markup, "<foo>\u2018\u2019\u201c\u201d</foo>")
def test_smart_quotes_to_xml_entities(self):
markup = b"<foo>\x91\x92\x93\x94</foo>"
dammit = UnicodeDammit(markup, smart_quotes_to="xml")
self.assertEqual(
dammit.unicode_markup, "<foo>&#x2018;&#x2019;&#x201C;&#x201D;</foo>")
def test_smart_quotes_to_html_entities(self):
markup = b"<foo>\x91\x92\x93\x94</foo>"
dammit = UnicodeDammit(markup, smart_quotes_to="html")
self.assertEqual(
dammit.unicode_markup, "<foo>&lsquo;&rsquo;&ldquo;&rdquo;</foo>")
def test_smart_quotes_to_ascii(self):
markup = b"<foo>\x91\x92\x93\x94</foo>"
dammit = UnicodeDammit(markup, smart_quotes_to="ascii")
self.assertEqual(
dammit.unicode_markup, """<foo>''""</foo>""")
def test_detect_utf8(self):
utf8 = b"Sacr\xc3\xa9 bleu! \xe2\x98\x83"
dammit = UnicodeDammit(utf8)
self.assertEqual(dammit.original_encoding.lower(), 'utf-8')
self.assertEqual(dammit.unicode_markup, 'Sacr\xe9 bleu! \N{SNOWMAN}')
def test_convert_hebrew(self):
hebrew = b"\xed\xe5\xec\xf9"
dammit = UnicodeDammit(hebrew, ["iso-8859-8"])
self.assertEqual(dammit.original_encoding.lower(), 'iso-8859-8')
self.assertEqual(dammit.unicode_markup, '\u05dd\u05d5\u05dc\u05e9')
def test_dont_see_smart_quotes_where_there_are_none(self):
utf_8 = b"\343\202\261\343\203\274\343\202\277\343\202\244 Watch"
dammit = UnicodeDammit(utf_8)
self.assertEqual(dammit.original_encoding.lower(), 'utf-8')
self.assertEqual(dammit.unicode_markup.encode("utf-8"), utf_8)
def test_ignore_inappropriate_codecs(self):
utf8_data = "Räksmörgås".encode("utf-8")
dammit = UnicodeDammit(utf8_data, ["iso-8859-8"])
self.assertEqual(dammit.original_encoding.lower(), 'utf-8')
def test_ignore_invalid_codecs(self):
utf8_data = "Räksmörgås".encode("utf-8")
for bad_encoding in ['.utf8', '...', 'utF---16.!']:
dammit = UnicodeDammit(utf8_data, [bad_encoding])
self.assertEqual(dammit.original_encoding.lower(), 'utf-8')
def test_exclude_encodings(self):
# This is UTF-8.
utf8_data = "Räksmörgås".encode("utf-8")
# But if we exclude UTF-8 from consideration, the guess is
# Windows-1252.
dammit = UnicodeDammit(utf8_data, exclude_encodings=["utf-8"])
self.assertEqual(dammit.original_encoding.lower(), 'windows-1252')
# And if we exclude that, there is no valid guess at all.
dammit = UnicodeDammit(
utf8_data, exclude_encodings=["utf-8", "windows-1252"])
self.assertEqual(dammit.original_encoding, None)
def test_encoding_detector_replaces_junk_in_encoding_name_with_replacement_character(self):
detected = EncodingDetector(
b'<?xml version="1.0" encoding="UTF-\xdb" ?>')
encodings = list(detected.encodings)
assert 'utf-\N{REPLACEMENT CHARACTER}' in encodings
def test_detect_html5_style_meta_tag(self):
for data in (
b'<html><meta charset="euc-jp" /></html>',
b"<html><meta charset='euc-jp' /></html>",
b"<html><meta charset=euc-jp /></html>",
b"<html><meta charset=euc-jp/></html>"):
dammit = UnicodeDammit(data, is_html=True)
self.assertEqual(
"euc-jp", dammit.original_encoding)
def test_last_ditch_entity_replacement(self):
# This is a UTF-8 document that contains bytestrings
# completely incompatible with UTF-8 (ie. encoded with some other
# encoding).
#
# Since there is no consistent encoding for the document,
# Unicode, Dammit will eventually encode the document as UTF-8
# and encode the incompatible characters as REPLACEMENT
# CHARACTER.
#
# If chardet is installed, it will detect that the document
# can be converted into ISO-8859-1 without errors. This happens
# to be the wrong encoding, but it is a consistent encoding, so the
# code we're testing here won't run.
#
# So we temporarily disable chardet if it's present.
doc = b"""\357\273\277<?xml version="1.0" encoding="UTF-8"?>
<html><b>\330\250\330\252\330\261</b>
<i>\310\322\321\220\312\321\355\344</i></html>"""
chardet = bs4.dammit.chardet_dammit
logging.disable(logging.WARNING)
try:
def noop(str):
return None
bs4.dammit.chardet_dammit = noop
dammit = UnicodeDammit(doc)
self.assertEqual(True, dammit.contains_replacement_characters)
self.assertTrue("\ufffd" in dammit.unicode_markup)
soup = BeautifulSoup(doc, "html.parser")
self.assertTrue(soup.contains_replacement_characters)
finally:
logging.disable(logging.NOTSET)
bs4.dammit.chardet_dammit = chardet
def test_byte_order_mark_removed(self):
# A document written in UTF-16LE will have its byte order marker stripped.
data = b'\xff\xfe<\x00a\x00>\x00\xe1\x00\xe9\x00<\x00/\x00a\x00>\x00'
dammit = UnicodeDammit(data)
self.assertEqual("<a>áé</a>", dammit.unicode_markup)
self.assertEqual("utf-16le", dammit.original_encoding)
def test_detwingle(self):
# Here's a UTF8 document.
utf8 = ("\N{SNOWMAN}" * 3).encode("utf8")
# Here's a Windows-1252 document.
windows_1252 = (
"\N{LEFT DOUBLE QUOTATION MARK}Hi, I like Windows!"
"\N{RIGHT DOUBLE QUOTATION MARK}").encode("windows_1252")
# Through some unholy alchemy, they've been stuck together.
doc = utf8 + windows_1252 + utf8
# The document can't be turned into UTF-8:
self.assertRaises(UnicodeDecodeError, doc.decode, "utf8")
# Unicode, Dammit thinks the whole document is Windows-1252,
# and decodes it into "☃☃☃“Hi, I like Windows!”☃☃☃"
# But if we run it through fix_embedded_windows_1252, it's fixed:
fixed = UnicodeDammit.detwingle(doc)
self.assertEqual(
"☃☃☃“Hi, I like Windows!”☃☃☃", fixed.decode("utf8"))
def test_detwingle_ignores_multibyte_characters(self):
# Each of these characters has a UTF-8 representation ending
# in \x93. \x93 is a smart quote if interpreted as
# Windows-1252. But our code knows to skip over multibyte
# UTF-8 characters, so they'll survive the process unscathed.
for tricky_unicode_char in (
"\N{LATIN SMALL LIGATURE OE}", # 2-byte char '\xc5\x93'
"\N{LATIN SUBSCRIPT SMALL LETTER X}", # 3-byte char '\xe2\x82\x93'
"\xf0\x90\x90\x93", # This is a CJK character, not sure which one.
):
input = tricky_unicode_char.encode("utf8")
self.assertTrue(input.endswith(b'\x93'))
output = UnicodeDammit.detwingle(input)
self.assertEqual(output, input)
class TestNamedspacedAttribute(SoupTest):
def test_name_may_be_none(self):
def test_name_may_be_none_or_missing(self):
a = NamespacedAttribute("xmlns", None)
self.assertEqual(a, "xmlns")
a = NamespacedAttribute("xmlns", "")
self.assertEqual(a, "xmlns")
a = NamespacedAttribute("xmlns")
self.assertEqual(a, "xmlns")
def test_namespace_may_be_none_or_missing(self):
a = NamespacedAttribute(None, "tag")
self.assertEqual(a, "tag")
a = NamespacedAttribute("", "tag")
self.assertEqual(a, "tag")
def test_attribute_is_equivalent_to_colon_separated_string(self):
a = NamespacedAttribute("a", "b")
self.assertEqual("a:b", a)


@ -27,13 +27,17 @@ from bs4.element import (
Doctype,
Formatter,
NavigableString,
Script,
SoupStrainer,
Stylesheet,
Tag,
TemplateString,
)
from bs4.testing import (
SoupTest,
skipIf,
)
from soupsieve import SelectorSyntaxError
XML_BUILDER_PRESENT = (builder_registry.lookup("xml") is not None)
LXML_PRESENT = (builder_registry.lookup("lxml") is not None)
@@ -741,6 +745,30 @@ class TestPreviousSibling(SiblingTest):
self.assertEqual(start.find_previous_sibling(text="nonesuch"), None)
class TestTag(SoupTest):
# Test various methods of Tag.
def test__should_pretty_print(self):
# Test the rules about when a tag should be pretty-printed.
tag = self.soup("").new_tag("a_tag")
# No list of whitespace-preserving tags -> pretty-print
tag._preserve_whitespace_tags = None
self.assertEqual(True, tag._should_pretty_print(0))
# List exists but tag is not on the list -> pretty-print
tag.preserve_whitespace_tags = ["some_other_tag"]
self.assertEqual(True, tag._should_pretty_print(1))
# Indent level is None -> don't pretty-print
self.assertEqual(False, tag._should_pretty_print(None))
# Tag is on the whitespace-preserving list -> don't pretty-print
tag.preserve_whitespace_tags = ["some_other_tag", "a_tag"]
self.assertEqual(False, tag._should_pretty_print(1))
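    # The user-visible version of these rules, as a sketch: prettify()
    # indents ordinary markup but leaves whitespace-preserving tags
    # such as <pre> alone.
    #
    #   from bs4 import BeautifulSoup
    #   soup = BeautifulSoup(
    #       "<div><pre>keep  this</pre><p>indent me</p></div>", "html.parser")
    #   print(soup.prettify())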
class TestTagCreation(SoupTest):
"""Test the ability to create new tags."""
def test_new_tag(self):
@@ -981,6 +1009,15 @@ class TestTreeModification(SoupTest):
soup.a.extend(l)
self.assertEqual("<a><g></g><f></f><e></e><d></d><c></c><b></b></a>", soup.decode())
def test_extend_with_another_tags_contents(self):
data = '<body><div id="d1"><a>1</a><a>2</a><a>3</a><a>4</a></div><div id="d2"></div></body>'
soup = self.soup(data)
d1 = soup.find('div', id='d1')
d2 = soup.find('div', id='d2')
d2.extend(d1)
self.assertEqual('<div id="d1"></div>', d1.decode())
self.assertEqual('<div id="d2"><a>1</a><a>2</a><a>3</a><a>4</a></div>', d2.decode())
def test_move_tag_to_beginning_of_parent(self):
data = "<a><b></b><c></c><d></d></a>"
soup = self.soup(data)
@@ -1093,6 +1130,37 @@ class TestTreeModification(SoupTest):
self.assertEqual(no.next_element, "no")
self.assertEqual(no.next_sibling, " business")
def test_replace_with_errors(self):
# Can't replace a tag that's not part of a tree.
a_tag = Tag(name="a")
self.assertRaises(ValueError, a_tag.replace_with, "won't work")
# Can't replace a tag with its parent.
a_tag = self.soup("<a><b></b></a>").a
self.assertRaises(ValueError, a_tag.b.replace_with, a_tag)
# Or with a list that includes its parent.
self.assertRaises(ValueError, a_tag.b.replace_with,
"string1", a_tag, "string2")
def test_replace_with_multiple(self):
data = "<a><b></b><c></c></a>"
soup = self.soup(data)
d_tag = soup.new_tag("d")
d_tag.string = "Text In D Tag"
e_tag = soup.new_tag("e")
f_tag = soup.new_tag("f")
a_string = "Random Text"
soup.c.replace_with(d_tag, e_tag, a_string, f_tag)
self.assertEqual(
"<a><b></b><d>Text In D Tag</d><e></e>Random Text<f></f></a>",
soup.decode()
)
assert soup.b.next_element == d_tag
assert d_tag.string.next_element == e_tag
assert e_tag.next_element.string == a_string
assert e_tag.next_element.next_element == f_tag
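    # A sketch of the multi-argument form exercised above: every
    # argument to replace_with() is spliced in where the original
    # element stood.
    #
    #   from bs4 import BeautifulSoup
    #   soup = BeautifulSoup("<p><b>old</b></p>", "html.parser")
    #   i_tag = soup.new_tag("i")
    #   i_tag.string = "new"
    #   soup.b.replace_with(i_tag, " and more")
    #   print(soup)  # <p><i>new</i> and more</p>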
def test_replace_first_child(self):
data = "<a><b></b><c></c></a>"
soup = self.soup(data)
@@ -1251,6 +1319,23 @@ class TestTreeModification(SoupTest):
a.clear(decompose=True)
self.assertEqual(0, len(em.contents))
def test_decompose(self):
# Test PageElement.decompose() and PageElement.decomposed
soup = self.soup("<p><a>String <em>Italicized</em></a></p><p>Another para</p>")
p1, p2 = soup.find_all('p')
a = p1.a
text = p1.em.string
for i in [p1, p2, a, text]:
self.assertEqual(False, i.decomposed)
# This sets p1 and everything beneath it to decomposed.
p1.decompose()
for i in [p1, a, text]:
self.assertEqual(True, i.decomposed)
# p2 is unaffected.
self.assertEqual(False, p2.decomposed)
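    # A sketch of the property under test: decompose() destroys an
    # element in place, and .decomposed reports that it happened.
    #
    #   from bs4 import BeautifulSoup
    #   soup = BeautifulSoup("<p>gone</p><p>kept</p>", "html.parser")
    #   first = soup.p
    #   first.decompose()
    #   print(first.decomposed)  # True
    #   print(soup)              # <p>kept</p>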
def test_string_set(self):
"""Tag.string = 'string'"""
soup = self.soup("<a></a> <b><c></c></b>")
@@ -1367,7 +1452,7 @@ class TestElementObjects(SoupTest):
self.assertEqual(soup.a.get_text(","), "a,r, , t ")
self.assertEqual(soup.a.get_text(",", strip=True), "a,r,t")
def test_get_text_ignores_comments(self):
def test_get_text_ignores_special_string_containers(self):
soup = self.soup("foo<!--IGNORE-->bar")
self.assertEqual(soup.get_text(), "foobar")
@@ -1376,10 +1461,51 @@ class TestElementObjects(SoupTest):
self.assertEqual(
soup.get_text(types=None), "fooIGNOREbar")
def test_all_strings_ignores_comments(self):
soup = self.soup("foo<style>CSS</style><script>Javascript</script>bar")
self.assertEqual(soup.get_text(), "foobar")
def test_all_strings_ignores_special_string_containers(self):
soup = self.soup("foo<!--IGNORE-->bar")
self.assertEqual(['foo', 'bar'], list(soup.strings))
soup = self.soup("foo<style>CSS</style><script>Javascript</script>bar")
self.assertEqual(['foo', 'bar'], list(soup.strings))
def test_string_methods_inside_special_string_container_tags(self):
# Strings inside tags like <script> are generally ignored by
# methods like get_text, because they're not what humans
# consider 'text'. But if you call get_text on the <script>
# tag itself, those strings _are_ considered to be 'text',
# because there's nothing else you might be looking for.
style = self.soup("<div>a<style>Some CSS</style></div>")
template = self.soup("<div>a<template><p>Templated <b>text</b>.</p><!--With a comment.--></template></div>")
script = self.soup("<div>a<script><!--a comment-->Some text</script></div>")
self.assertEqual(style.div.get_text(), "a")
self.assertEqual(list(style.div.strings), ["a"])
self.assertEqual(style.div.style.get_text(), "Some CSS")
self.assertEqual(list(style.div.style.strings),
['Some CSS'])
# The comment is not picked up here. That's because it was
# parsed into a Comment object, which is not considered
# interesting by template.strings.
self.assertEqual(template.div.get_text(), "a")
self.assertEqual(list(template.div.strings), ["a"])
self.assertEqual(template.div.template.get_text(), "Templated text.")
self.assertEqual(list(template.div.template.strings),
["Templated ", "text", "."])
# The comment is included here, because it didn't get parsed
# into a Comment object--it's part of the Script string.
self.assertEqual(script.div.get_text(), "a")
self.assertEqual(list(script.div.strings), ["a"])
self.assertEqual(script.div.script.get_text(),
"<!--a comment-->Some text")
self.assertEqual(list(script.div.script.strings),
['<!--a comment-->Some text'])
class TestCDAtaListAttributes(SoupTest):
"""Testing cdata-list attributes like 'class'.
@@ -1466,6 +1592,31 @@ class TestPersistence(SoupTest):
self.assertEqual("<p> </p>", str(copy))
self.assertEqual(encoding, copy.original_encoding)
def test_copy_preserves_builder_information(self):
tag = self.soup('<p></p>').p
# Simulate a tag obtained from a source file.
tag.sourceline = 10
tag.sourcepos = 33
copied = tag.__copy__()
# The TreeBuilder object is no longer available, but information
# obtained from it gets copied over to the new Tag object.
self.assertEqual(tag.sourceline, copied.sourceline)
self.assertEqual(tag.sourcepos, copied.sourcepos)
self.assertEqual(
tag.can_be_empty_element, copied.can_be_empty_element
)
self.assertEqual(
tag.cdata_list_attributes, copied.cdata_list_attributes
)
self.assertEqual(
tag.preserve_whitespace_tags, copied.preserve_whitespace_tags
)
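    # A sketch of the same guarantee: a copy keeps parser-derived
    # details such as sourceline even though the TreeBuilder itself is
    # not carried along.
    #
    #   import copy
    #   from bs4 import BeautifulSoup
    #   tag = BeautifulSoup("<p>x</p>", "html.parser").p
    #   print(copy.copy(tag).sourceline == tag.sourceline)  # True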
def test_unicode_pickle(self):
# A tree containing Unicode characters can be pickled.
html = "<b>\N{SNOWMAN}</b>"
@@ -1726,71 +1877,7 @@ class TestEncoding(SoupTest):
else:
self.assertEqual(b'<b>\\u2603</b>', repr(soup))
class TestFormatter(SoupTest):
def test_sort_attributes(self):
# Test the ability to override Formatter.attributes() to,
# e.g., disable the normal sorting of attributes.
class UnsortedFormatter(Formatter):
def attributes(self, tag):
self.called_with = tag
for k, v in sorted(tag.attrs.items()):
if k == 'ignore':
continue
yield k, v
soup = self.soup('<p cval="1" aval="2" ignore="ignored"></p>')
formatter = UnsortedFormatter()
decoded = soup.decode(formatter=formatter)
# attributes() was called on the <p> tag. It filtered out one
# attribute and sorted the other two.
self.assertEqual(formatter.called_with, soup.p)
self.assertEqual('<p aval="2" cval="1"></p>', decoded)
class TestNavigableStringSubclasses(SoupTest):
def test_cdata(self):
# None of the current builders turn CDATA sections into CData
# objects, but you can create them manually.
soup = self.soup("")
cdata = CData("foo")
soup.insert(1, cdata)
self.assertEqual(str(soup), "<![CDATA[foo]]>")
self.assertEqual(soup.find(text="foo"), "foo")
self.assertEqual(soup.contents[0], "foo")
def test_cdata_is_never_formatted(self):
"""Text inside a CData object is passed into the formatter.
But the return value is ignored.
"""
self.count = 0
def increment(*args):
self.count += 1
return "BITTER FAILURE"
soup = self.soup("")
cdata = CData("<><><>")
soup.insert(1, cdata)
self.assertEqual(
b"<![CDATA[<><><>]]>", soup.encode(formatter=increment))
self.assertEqual(1, self.count)
def test_doctype_ends_in_newline(self):
# Unlike other NavigableString subclasses, a DOCTYPE always ends
# in a newline.
doctype = Doctype("foo")
soup = self.soup("")
soup.insert(1, doctype)
self.assertEqual(soup.encode(), b"<!DOCTYPE foo>\n")
def test_declaration(self):
d = Declaration("foo")
self.assertEqual("<?foo?>", d.output_ready())
class TestSoupSelector(TreeTest):
HTML = """
@@ -1900,7 +1987,7 @@ class TestSoupSelector(TreeTest):
self.assertEqual(len(self.soup.select('del')), 0)
def test_invalid_tag(self):
self.assertRaises(SyntaxError, self.soup.select, 'tag%t')
self.assertRaises(SelectorSyntaxError, self.soup.select, 'tag%t')
def test_select_dashed_tag_ids(self):
self.assertSelects('custom-dashed-tag', ['dash1', 'dash2'])
@@ -2091,7 +2178,7 @@ class TestSoupSelector(TreeTest):
NotImplementedError, self.soup.select, "a:no-such-pseudoclass")
self.assertRaises(
SyntaxError, self.soup.select, "a:nth-of-type(a)")
SelectorSyntaxError, self.soup.select, "a:nth-of-type(a)")
def test_nth_of_type(self):
# Try to select first paragraph
@@ -2147,7 +2234,7 @@ class TestSoupSelector(TreeTest):
self.assertEqual([], self.soup.select('#inner ~ h2'))
def test_dangling_combinator(self):
self.assertRaises(SyntaxError, self.soup.select, 'h1 >')
self.assertRaises(SelectorSyntaxError, self.soup.select, 'h1 >')
def test_sibling_combinator_wont_select_same_tag_twice(self):
self.assertSelects('p[lang] ~ p', ['lang-en-gb', 'lang-en-us', 'lang-fr'])
@@ -2178,8 +2265,8 @@ class TestSoupSelector(TreeTest):
self.assertSelects('div x,y, z', ['xid', 'yid', 'zida', 'zidb', 'zidab', 'zidac'])
def test_invalid_multiple_select(self):
self.assertRaises(SyntaxError, self.soup.select, ',x, y')
self.assertRaises(SyntaxError, self.soup.select, 'x,,y')
self.assertRaises(SelectorSyntaxError, self.soup.select, ',x, y')
self.assertRaises(SelectorSyntaxError, self.soup.select, 'x,,y')
def test_multiple_select_attrs(self):
self.assertSelects('p[lang=en], p[lang=en-gb]', ['lang-en', 'lang-en-gb'])
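    # A sketch of the change running through this hunk: the upgraded
    # soupsieve raises its own SelectorSyntaxError for malformed
    # selectors, replacing the bare SyntaxError the old tests expected.
    #
    #   from bs4 import BeautifulSoup
    #   from soupsieve import SelectorSyntaxError
    #   try:
    #       BeautifulSoup("<p></p>", "html.parser").select("h1 >")
    #   except SelectorSyntaxError as e:
    #       print("invalid selector:", e)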

View file

@@ -1,3 +1,3 @@
from .core import where
from .core import contents, where
__version__ = "2019.09.11"
__version__ = "2021.10.08"

View file

@@ -1,2 +1,12 @@
from certifi import where
print(where())
import argparse
from certifi import contents, where
parser = argparse.ArgumentParser()
parser.add_argument("-c", "--contents", action="store_true")
args = parser.parse_args()
if args.contents:
print(contents())
else:
print(where())
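# A sketch of the new surface added here: besides where(), certifi now
# exposes contents(), and this __main__ accepts -c/--contents
# (i.e. python -m certifi -c) to dump the bundle itself.
#
#   import certifi
#   print(certifi.where())          # filesystem path to the CA bundle
#   print(certifi.contents()[:64])  # start of the PEM text itself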

Some files were not shown because too many files have changed in this diff.