Merge branch 'development' into python3

# Conflicts:
#   bazarr/database.py
#   bazarr/get_episodes.py
#   bazarr/get_movies.py
#   bazarr/get_series.py
#   bazarr/get_subtitle.py
#   bazarr/list_subtitles.py
#   bazarr/main.py
#   views/movie.tpl

Commit fe6758858a: 674 changed files with 4721 additions and 19904 deletions
@@ -9,7 +9,7 @@ import tarfile

from get_args import args
from config import settings, bazarr_url
from queueconfig import notifications
from database import System
from database import database

if not args.no_update and not args.release_update:
    import git

@@ -297,8 +297,8 @@ def updated(restart=True):

        try:
            from main import restart
            restart()
        except requests.ConnectionError:
        except:
            logging.info('BAZARR Restart failed, please restart Bazarr manualy')
            updated(restart=False)
    else:
        System.update({System.updated: 1}).execute()
        database.execute("UPDATE system SET updated='1'")
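The hunk above is representative of the whole merge: peewee model calls are replaced by raw SQL run through the shared database object. A minimal sketch of the same update, assuming only that the worker exposes an execute(query) method; a plain sqlite3 connection stands in for it here.

# Illustrative sketch, not part of this commit: a plain sqlite3 connection
# stands in for Bazarr's shared database worker.
import sqlite3

conn = sqlite3.connect(':memory:')
conn.execute("CREATE TABLE system (configured TEXT, updated TEXT)")
conn.execute("INSERT INTO system (configured, updated) VALUES ('0', '0')")

# Before: System.update({System.updated: 1}).execute()
# After:  raw SQL through the shared worker object
conn.execute("UPDATE system SET updated='1'")
print(conn.execute("SELECT updated FROM system").fetchone())  # ('1',)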

bazarr/create_db.sql (new file, 88 lines)

@@ -0,0 +1,88 @@
BEGIN TRANSACTION;
CREATE TABLE "table_shows" (
    `tvdbId` INTEGER NOT NULL UNIQUE,
    `title` TEXT NOT NULL,
    `path` TEXT NOT NULL UNIQUE,
    `languages` TEXT,
    `hearing_impaired` TEXT,
    `sonarrSeriesId` INTEGER NOT NULL UNIQUE,
    `overview` TEXT,
    `poster` TEXT,
    `fanart` TEXT,
    `audio_language` "text",
    `sortTitle` "text",
    PRIMARY KEY(`tvdbId`)
);
CREATE TABLE "table_settings_providers" (
    `name` TEXT NOT NULL UNIQUE,
    `enabled` INTEGER,
    `username` "text",
    `password` "text",
    PRIMARY KEY(`name`)
);
CREATE TABLE "table_settings_notifier" (
    `name` TEXT,
    `url` TEXT,
    `enabled` INTEGER,
    PRIMARY KEY(`name`)
);
CREATE TABLE "table_settings_languages" (
    `code3` TEXT NOT NULL UNIQUE,
    `code2` TEXT,
    `name` TEXT NOT NULL,
    `enabled` INTEGER,
    `code3b` TEXT,
    PRIMARY KEY(`code3`)
);
CREATE TABLE "table_history" (
    `id` INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT UNIQUE,
    `action` INTEGER NOT NULL,
    `sonarrSeriesId` INTEGER NOT NULL,
    `sonarrEpisodeId` INTEGER NOT NULL,
    `timestamp` INTEGER NOT NULL,
    `description` TEXT NOT NULL
);
CREATE TABLE "table_episodes" (
    `sonarrSeriesId` INTEGER NOT NULL,
    `sonarrEpisodeId` INTEGER NOT NULL UNIQUE,
    `title` TEXT NOT NULL,
    `path` TEXT NOT NULL,
    `season` INTEGER NOT NULL,
    `episode` INTEGER NOT NULL,
    `subtitles` TEXT,
    `missing_subtitles` TEXT,
    `scene_name` TEXT,
    `monitored` TEXT,
    `failedAttempts` "text"
);
CREATE TABLE "table_movies" (
    `tmdbId` TEXT NOT NULL UNIQUE,
    `title` TEXT NOT NULL,
    `path` TEXT NOT NULL UNIQUE,
    `languages` TEXT,
    `subtitles` TEXT,
    `missing_subtitles` TEXT,
    `hearing_impaired` TEXT,
    `radarrId` INTEGER NOT NULL UNIQUE,
    `overview` TEXT,
    `poster` TEXT,
    `fanart` TEXT,
    `audio_language` "text",
    `sceneName` TEXT,
    `monitored` TEXT,
    `failedAttempts` "text",
    PRIMARY KEY(`tmdbId`)
);
CREATE TABLE "table_history_movie" (
    `id` INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT UNIQUE,
    `action` INTEGER NOT NULL,
    `radarrId` INTEGER NOT NULL,
    `timestamp` INTEGER NOT NULL,
    `description` TEXT NOT NULL
);
CREATE TABLE "system" (
    `configured` TEXT,
    `updated` TEXT
);
INSERT INTO `system` (configured, updated) VALUES ('0', '0');
COMMIT;
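This schema file is consumed by the new db_init() shown further down, which reads it and feeds it to sqlite3's executescript(). A standalone sketch of that bootstrap, assuming create_db.sql sits next to the script; the 'bazarr.db' path is a placeholder.

# Illustrative sketch, not part of this commit: bootstrapping an empty database
# from create_db.sql the way the new db_init() does.
import os
import sqlite3

script_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'create_db.sql')
with open(script_path, 'r') as fd:
    script = fd.read()

db = sqlite3.connect('bazarr.db', timeout=30)
db.cursor().executescript(script)  # runs the whole BEGIN TRANSACTION ... COMMIT block
db.close()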

@@ -1,215 +1,129 @@
from __future__ import absolute_import
import os
import atexit
from sqlite3worker import Sqlite3Worker
from six import string_types

from get_args import args
from peewee import *
from playhouse.sqliteq import SqliteQueueDatabase
from playhouse.migrate import *

database = SqliteQueueDatabase(
    os.path.join(args.config_dir, 'db', 'bazarr.db'),
    use_gevent=False,
    autostart=True,
    queue_max_size=256,  # Max. # of pending writes that can accumulate.
    results_timeout=30.0  # Max. time to wait for query to be executed.
)

database.pragma('wal_checkpoint', 'TRUNCATE')  # Run a checkpoint and merge remaining wal-journal.
database.cache_size = -1024  # Number of KB of cache for wal-journal.
                             # Must be negative because positive means number of pages.
database.wal_autocheckpoint = 50  # Run an automatic checkpoint every 50 write transactions.
from helper import path_replace, path_replace_movie, path_replace_reverse, path_replace_reverse_movie


@database.func('path_substitution')
def path_substitution(path):
    from helper import path_replace
    return path_replace(path)
def db_init():
    import sqlite3
    import os
    import logging

    from get_args import args

    if not os.path.exists(os.path.join(args.config_dir, 'db', 'bazarr.db')):
        # Get SQL script from file
        fd = open(os.path.join(os.path.dirname(__file__), 'create_db.sql'), 'r')
        script = fd.read()
        # Close SQL script file
        fd.close()
        # Open database connection
        db = sqlite3.connect(os.path.join(args.config_dir, 'db', 'bazarr.db'), timeout=30)
        c = db.cursor()
        # Execute script and commit change to database
        c.executescript(script)
        # Close database connection
        db.close()
        logging.info('BAZARR Database created successfully')


@database.func('path_substitution_movie')
def path_substitution_movie(path):
    from helper import path_replace_movie
    return path_replace_movie(path)
database = Sqlite3Worker(os.path.join(args.config_dir, 'db', 'bazarr.db'), max_queue_size=256, as_dict=True)


class UnknownField(object):
    def __init__(self, *_, **__): pass
class SqliteDictConverter:
    def __init__(self):
        self.keys_insert = tuple()
        self.keys_update = tuple()
        self.values = tuple()
        self.question_marks = tuple()

class BaseModel(Model):
    class Meta:
        database = database
    def convert(self, values_dict):
        if type(values_dict) is dict:
            self.keys_insert = tuple()
            self.keys_update = tuple()
            self.values = tuple()
            self.question_marks = tuple()

            temp_keys = list()
            temp_values = list()
            for item in values_dict.items():
                temp_keys.append(item[0])
                temp_values.append(item[1])
            self.keys_insert = ','.join(temp_keys)
            self.keys_update = ','.join([k + '=?' for k in temp_keys])
            self.values = tuple(temp_values)
            self.question_marks = ','.join(list('?'*len(values_dict)))
            return self
        else:
            pass


class System(BaseModel):
    configured = TextField(null=True)
    updated = TextField(null=True)

    class Meta:
        table_name = 'system'
        primary_key = False
dict_converter = SqliteDictConverter()


class TableShows(BaseModel):
    alternate_titles = TextField(column_name='alternateTitles', null=True)
    audio_language = TextField(null=True)
    fanart = TextField(null=True)
    forced = TextField(null=True, constraints=[SQL('DEFAULT "False"')])
    hearing_impaired = TextField(null=True)
    languages = TextField(null=True)
    overview = TextField(null=True)
    path = TextField(null=False, unique=True)
    poster = TextField(null=True)
    sonarr_series_id = IntegerField(column_name='sonarrSeriesId', null=True, unique=True)
    sort_title = TextField(column_name='sortTitle', null=True)
    title = TextField(null=True)
    tvdb_id = IntegerField(column_name='tvdbId', null=True, unique=True, primary_key=True)
    year = TextField(null=True)
class SqliteDictPathMapper:
    def __init__(self):
        pass

    class Meta:
        table_name = 'table_shows'
    def path_replace(self, values_dict):
        if type(values_dict) is list:
            for item in values_dict:
                item['path'] = path_replace(item['path'])
        elif type(values_dict) is dict:
            values_dict['path'] = path_replace(values_dict['path'])
        else:
            return path_replace(values_dict)

    def path_replace_movie(self, values_dict):
        if type(values_dict) is list:
            for item in values_dict:
                item['path'] = path_replace_movie(item['path'])
        elif type(values_dict) is dict:
            values_dict['path'] = path_replace_movie(values_dict['path'])
        else:
            return path_replace(values_dict)


class TableEpisodes(BaseModel):
    rowid = IntegerField()
    audio_codec = TextField(null=True)
    episode = IntegerField(null=False)
    failed_attempts = TextField(column_name='failedAttempts', null=True)
    format = TextField(null=True)
    missing_subtitles = TextField(null=True)
    monitored = TextField(null=True)
    path = TextField(null=False)
    resolution = TextField(null=True)
    scene_name = TextField(null=True)
    season = IntegerField(null=False)
    sonarr_episode_id = IntegerField(column_name='sonarrEpisodeId', unique=True, null=False)
    sonarr_series_id = ForeignKeyField(TableShows, field='sonarr_series_id', column_name='sonarrSeriesId', null=False)
    subtitles = TextField(null=True)
    title = TextField(null=True)
    video_codec = TextField(null=True)
    episode_file_id = IntegerField(null=True)

    class Meta:
        table_name = 'table_episodes'
        primary_key = False
dict_mapper = SqliteDictPathMapper()


class TableMovies(BaseModel):
    rowid = IntegerField()
    alternative_titles = TextField(column_name='alternativeTitles', null=True)
    audio_codec = TextField(null=True)
    audio_language = TextField(null=True)
    failed_attempts = TextField(column_name='failedAttempts', null=True)
    fanart = TextField(null=True)
    forced = TextField(null=True, constraints=[SQL('DEFAULT "False"')])
    format = TextField(null=True)
    hearing_impaired = TextField(null=True)
    imdb_id = TextField(column_name='imdbId', null=True)
    languages = TextField(null=True)
    missing_subtitles = TextField(null=True)
    monitored = TextField(null=True)
    overview = TextField(null=True)
    path = TextField(unique=True)
    poster = TextField(null=True)
    radarr_id = IntegerField(column_name='radarrId', null=False, unique=True)
    resolution = TextField(null=True)
    scene_name = TextField(column_name='sceneName', null=True)
    sort_title = TextField(column_name='sortTitle', null=True)
    subtitles = TextField(null=True)
    title = TextField(null=False)
    tmdb_id = TextField(column_name='tmdbId', primary_key=True, null=False)
    video_codec = TextField(null=True)
    year = TextField(null=True)
    movie_file_id = IntegerField(null=True)
def db_upgrade():
    columnToAdd = [
        ['table_shows', 'year', 'text'],
        ['table_shows', 'alternateTitles', 'text'],
        ['table_shows', 'forced', 'text', 'False'],
        ['table_episodes', 'format', 'text'],
        ['table_episodes', 'resolution', 'text'],
        ['table_episodes', 'video_codec', 'text'],
        ['table_episodes', 'audio_codec', 'text'],
        ['table_episodes', 'episode_file_id', 'integer'],
        ['table_movies', 'sortTitle', 'text'],
        ['table_movies', 'year', 'text'],
        ['table_movies', 'alternativeTitles', 'text'],
        ['table_movies', 'format', 'text'],
        ['table_movies', 'resolution', 'text'],
        ['table_movies', 'video_codec', 'text'],
        ['table_movies', 'audio_codec', 'text'],
        ['table_movies', 'imdbId', 'text'],
        ['table_movies', 'forced', 'text', 'False'],
        ['table_movies', 'movie_file_id', 'integer'],
        ['table_history', 'video_path', 'text'],
        ['table_history', 'language', 'text'],
        ['table_history', 'provider', 'text'],
        ['table_history', 'score', 'text'],
        ['table_history_movie', 'video_path', 'text'],
        ['table_history_movie', 'language', 'text'],
        ['table_history_movie', 'provider', 'text'],
        ['table_history_movie', 'score', 'text']
    ]

    class Meta:
        table_name = 'table_movies'


class TableHistory(BaseModel):
    id = PrimaryKeyField(null=False)
    action = IntegerField(null=False)
    description = TextField(null=False)
    language = TextField(null=True)
    provider = TextField(null=True)
    score = TextField(null=True)
    sonarr_episode_id = ForeignKeyField(TableEpisodes, field='sonarr_episode_id', column_name='sonarrEpisodeId', null=False)
    sonarr_series_id = ForeignKeyField(TableShows, field='sonarr_series_id', column_name='sonarrSeriesId', null=False)
    timestamp = IntegerField(null=False)
    video_path = TextField(null=True)

    class Meta:
        table_name = 'table_history'


class TableHistoryMovie(BaseModel):
    id = PrimaryKeyField(null=False)
    action = IntegerField(null=False)
    description = TextField(null=False)
    language = TextField(null=True)
    provider = TextField(null=True)
    radarr_id = ForeignKeyField(TableMovies, field='radarr_id', column_name='radarrId', null=False)
    score = TextField(null=True)
    timestamp = IntegerField(null=False)
    video_path = TextField(null=True)

    class Meta:
        table_name = 'table_history_movie'


class TableSettingsLanguages(BaseModel):
    code2 = TextField(null=False)
    code3 = TextField(null=False, unique=True, primary_key=True)
    code3b = TextField(null=True)
    enabled = IntegerField(null=True)
    name = TextField(null=False)

    class Meta:
        table_name = 'table_settings_languages'


class TableSettingsNotifier(BaseModel):
    enabled = IntegerField(null=False)
    name = TextField(null=False, primary_key=True)
    url = TextField(null=True)

    class Meta:
        table_name = 'table_settings_notifier'


# Database tables creation if they don't exists
models_list = [TableShows, TableEpisodes, TableMovies, TableHistory, TableHistoryMovie, TableSettingsLanguages,
               TableSettingsNotifier, System]
database.create_tables(models_list, safe=True)


# Database migration
migrator = SqliteMigrator(database)

# TableShows migration
table_shows_columns = []
for column in database.get_columns('table_shows'):
    table_shows_columns.append(column.name)
if 'forced' not in table_shows_columns:
    migrate(migrator.add_column('table_shows', 'forced', TableShows.forced))

# TableEpisodes migration
table_episodes_columns = []
for column in database.get_columns('table_episodes'):
    table_episodes_columns.append(column.name)
if 'episode_file_id' not in table_episodes_columns:
    migrate(migrator.add_column('table_episodes', 'episode_file_id', TableEpisodes.episode_file_id))

# TableMovies migration
table_movies_columns = []
for column in database.get_columns('table_movies'):
    table_movies_columns.append(column.name)
if 'forced' not in table_movies_columns:
    migrate(migrator.add_column('table_movies', 'forced', TableMovies.forced))
if 'movie_file_id' not in table_movies_columns:
    migrate(migrator.add_column('table_movies', 'movie_file_id', TableMovies.movie_file_id))


def wal_cleaning():
    database.pragma('wal_checkpoint', 'TRUNCATE')  # Run a checkpoint and merge remaining wal-journal.
    database.wal_autocheckpoint = 50  # Run an automatic checkpoint every 50 write transactions.
    for column in columnToAdd:
        try:
            if len(column) == 3:
                database.execute('''ALTER TABLE {0} ADD COLUMN "{1}" "{2}"'''.format(column[0], column[1], column[2]))
            else:
                database.execute('''ALTER TABLE {0} ADD COLUMN "{1}" "{2}" DEFAULT "{3}"'''.format(column[0], column[1], column[2], column[3]))
        except:
            pass
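The new SqliteDictConverter above turns a plain dict into the pieces of a parameterized statement (keys_insert, keys_update, question_marks, values); the sync modules below concatenate those into INSERT OR IGNORE and UPDATE statements. A self-contained sketch of that round trip, with the converter re-created locally and a plain sqlite3 connection instead of Bazarr's Sqlite3Worker.

# Illustrative sketch, not part of this commit.
import sqlite3

class SqliteDictConverter:
    def convert(self, values_dict):
        keys = list(values_dict.keys())
        self.keys_insert = ','.join(keys)
        self.keys_update = ','.join([k + '=?' for k in keys])
        self.values = tuple(values_dict.values())
        self.question_marks = ','.join('?' * len(values_dict))
        return self

dict_converter = SqliteDictConverter()
conn = sqlite3.connect(':memory:')
conn.execute("CREATE TABLE table_shows (tvdbId INTEGER PRIMARY KEY, title TEXT, path TEXT)")

row = {'tvdbId': 1234, 'title': 'Some Show', 'path': '/tv/Some Show'}
query = dict_converter.convert(row)
conn.execute('INSERT OR IGNORE INTO table_shows(' + query.keys_insert + ') VALUES(' +
             query.question_marks + ')', query.values)
conn.execute('UPDATE table_shows SET ' + query.keys_update + ' WHERE tvdbId = ?',
             query.values + (row['tvdbId'],))
print(conn.execute("SELECT * FROM table_shows").fetchall())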

@@ -1,5 +1,6 @@
from __future__ import absolute_import
import enzyme
from enzyme.exceptions import MalformedMKVError
import logging
import os
import subprocess

@@ -30,7 +31,10 @@ class EmbeddedSubsReader:
        else:
            if os.path.splitext(file)[1] == '.mkv':
                with open(file, 'rb') as f:
                    mkv = enzyme.MKV(f)
                    try:
                        mkv = enzyme.MKV(f)
                    except MalformedMKVError:
                        logging.error('BAZARR cannot analyze this MKV with our built-in MKV parser, you should install ffmpeg: ' + file)
                    for subtitle_track in mkv.subtitle_tracks:
                        subtitles_list.append([subtitle_track.language, subtitle_track.forced, subtitle_track.codec_id])

@@ -5,7 +5,7 @@ import requests
import logging
import re
from queueconfig import notifications
from database import database, TableShows, TableEpisodes, wal_cleaning
from database import database, dict_converter

from get_args import args
from config import settings, url_sonarr

@@ -17,7 +17,6 @@ from get_subtitle import episode_download_subtitles
def update_all_episodes():
    series_full_scan_subtitles()
    logging.info('BAZARR All existing episode subtitles indexed from disk.')
    wal_cleaning()


def sync_episodes():

@@ -26,13 +25,9 @@ def sync_episodes():
    apikey_sonarr = settings.sonarr.apikey

    # Get current episodes id in DB
    current_episodes_db = TableEpisodes.select(
        TableEpisodes.sonarr_episode_id,
        TableEpisodes.path,
        TableEpisodes.sonarr_series_id
    )
    current_episodes_db = database.execute("SELECT sonarrEpisodeId, path, sonarrSeriesId FROM table_episodes")

    current_episodes_db_list = [x.sonarr_episode_id for x in current_episodes_db]
    current_episodes_db_list = [x['sonarrEpisodeId'] for x in current_episodes_db]

    current_episodes_sonarr = []
    episodes_to_update = []

@@ -40,16 +35,13 @@ def sync_episodes():
    altered_episodes = []

    # Get sonarrId for each series from database
    seriesIdList = TableShows.select(
        TableShows.sonarr_series_id,
        TableShows.title
    )
    seriesIdList = database.execute("SELECT sonarrSeriesId, title FROM table_shows")

    seriesIdListLength = seriesIdList.count()
    seriesIdListLength = len(seriesIdList)
    for i, seriesId in enumerate(seriesIdList, 1):
        notifications.write(msg='Getting episodes data from Sonarr...', queue='get_episodes', item=i, length=seriesIdListLength)
        # Get episodes data for a series from Sonarr
        url_sonarr_api_episode = url_sonarr + "/api/episode?seriesId=" + str(seriesId.sonarr_series_id) + "&apikey=" + apikey_sonarr
        url_sonarr_api_episode = url_sonarr + "/api/episode?seriesId=" + str(seriesId['sonarrSeriesId']) + "&apikey=" + apikey_sonarr
        try:
            r = requests.get(url_sonarr_api_episode, timeout=60, verify=False)
            r.raise_for_status()

@@ -104,8 +96,8 @@ def sync_episodes():
                    current_episodes_sonarr.append(episode['id'])

                    if episode['id'] in current_episodes_db_list:
                        episodes_to_update.append({'sonarr_series_id': episode['seriesId'],
                                                   'sonarr_episode_id': episode['id'],
                        episodes_to_update.append({'sonarrSeriesId': episode['seriesId'],
                                                   'sonarrEpisodeId': episode['id'],
                                                   'title': episode['title'],
                                                   'path': episode['episodeFile']['path'],
                                                   'season': episode['seasonNumber'],

@@ -118,8 +110,8 @@ def sync_episodes():
                                                   'audio_codec': audioCodec,
                                                   'episode_file_id': episode['episodeFile']['id']})
                    else:
                        episodes_to_add.append({'sonarr_series_id': episode['seriesId'],
                                                'sonarr_episode_id': episode['id'],
                        episodes_to_add.append({'sonarrSeriesId': episode['seriesId'],
                                                'sonarrEpisodeId': episode['id'],
                                                'title': episode['title'],
                                                'path': episode['episodeFile']['path'],
                                                'season': episode['seasonNumber'],

@@ -134,21 +126,9 @@ def sync_episodes():

    # Update existing episodes in DB
    episode_in_db_list = []
    episodes_in_db = TableEpisodes.select(
        TableEpisodes.sonarr_series_id,
        TableEpisodes.sonarr_episode_id,
        TableEpisodes.title,
        TableEpisodes.path,
        TableEpisodes.season,
        TableEpisodes.episode,
        TableEpisodes.scene_name,
        TableEpisodes.monitored,
        TableEpisodes.format,
        TableEpisodes.resolution,
        TableEpisodes.video_codec,
        TableEpisodes.audio_codec,
        TableEpisodes.episode_file_id
    ).dicts()
    episodes_in_db = database.execute("SELECT sonarrSeriesId, sonarrEpisodeId, title, path, season, episode, "
                                      "scene_name, monitored, format, resolution, video_codec, audio_codec, "
                                      "episode_file_id FROM table_episodes")

    for item in episodes_in_db:
        episode_in_db_list.append(item)

@@ -156,30 +136,27 @@ def sync_episodes():
    episodes_to_update_list = [i for i in episodes_to_update if i not in episode_in_db_list]

    for updated_episode in episodes_to_update_list:
        TableEpisodes.update(
            updated_episode
        ).where(
            TableEpisodes.sonarr_episode_id == updated_episode['sonarr_episode_id']
        ).execute()
        altered_episodes.append([updated_episode['sonarr_episode_id'],
        query = dict_converter.convert(updated_episode)
        database.execute('''UPDATE table_episodes SET ''' + query.keys_update + ''' WHERE sonarrEpisodeId = ?''',
                         query.values + (updated_episode['sonarrEpisodeId'],))
        altered_episodes.append([updated_episode['sonarrEpisodeId'],
                                 updated_episode['path'],
                                 updated_episode['sonarr_series_id']])
                                 updated_episode['sonarrSeriesId']])

    # Insert new episodes in DB
    for added_episode in episodes_to_add:
        TableEpisodes.insert(
            added_episode
        ).on_conflict_ignore().execute()
        altered_episodes.append([added_episode['sonarr_episode_id'],
        query = dict_converter.convert(added_episode)
        database.execute(
            '''INSERT OR IGNORE INTO table_episodes(''' + query.keys_insert + ''') VALUES(''' + query.question_marks +
            ''')''', query.values)
        altered_episodes.append([added_episode['sonarrEpisodeId'],
                                 added_episode['path']])

    # Remove old episodes from DB
    removed_episodes = list(set(current_episodes_db_list) - set(current_episodes_sonarr))

    for removed_episode in removed_episodes:
        TableEpisodes.delete().where(
            TableEpisodes.sonarr_episode_id == removed_episode
        ).execute()
        database.execute("DELETE FROM table_episodes WHERE sonarrEpisodeId=?", (removed_episode,))

    # Store subtitles for added or modified episodes
    for i, altered_episode in enumerate(altered_episodes, 1):
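sync_episodes() reconciles the local table with Sonarr by set difference: IDs already in the database are updated, new IDs are inserted, and IDs no longer reported by Sonarr are deleted. A reduced sketch of that split, with made-up IDs standing in for the real API payloads.

# Illustrative sketch, not part of this commit.
current_episodes_db_list = [1, 2, 3]   # sonarrEpisodeId values already in the DB
episodes_from_sonarr = [2, 3, 4]       # IDs returned by the Sonarr API

episodes_to_update = [eid for eid in episodes_from_sonarr if eid in current_episodes_db_list]
episodes_to_add = [eid for eid in episodes_from_sonarr if eid not in current_episodes_db_list]
removed_episodes = list(set(current_episodes_db_list) - set(episodes_from_sonarr))

print(episodes_to_update)   # [2, 3] -> UPDATE table_episodes ...
print(episodes_to_add)      # [4]    -> INSERT OR IGNORE INTO table_episodes ...
print(removed_episodes)     # [1]    -> DELETE FROM table_episodes WHERE sonarrEpisodeId=?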

@@ -6,7 +6,7 @@ import pycountry

from get_args import args
from subzero.language import Language
from database import TableSettingsLanguages
from database import database


def load_language_in_db():

@@ -17,17 +17,11 @@ def load_language_in_db():

    # Insert languages in database table
    for lang in langs:
        TableSettingsLanguages.insert(
            lang
        ).on_conflict_ignore().execute()
        database.execute("INSERT OR IGNORE INTO table_settings_languages (code3, code2, name) VALUES (?, ?, ?)",
                         (lang['code3'], lang['code2'], lang['name']))

    TableSettingsLanguages.insert(
        {
            TableSettingsLanguages.code3: 'pob',
            TableSettingsLanguages.code2: 'pb',
            TableSettingsLanguages.name: 'Brazilian Portuguese'
        }
    ).on_conflict_ignore().execute()
    database.execute("INSERT OR IGNORE INTO table_settings_languages (code3, code2, name) "
                     "VALUES ('pob', 'pb', 'Brazilian Portuguese')")

    langs = [{'code3b': lang.bibliographic, 'code3': lang.alpha_3}
             for lang in pycountry.languages

@@ -35,85 +29,49 @@ def load_language_in_db():

    # Update languages in database table
    for lang in langs:
        TableSettingsLanguages.update(
            {
                TableSettingsLanguages.code3b: lang['code3b']
            }
        ).where(
            TableSettingsLanguages.code3 == lang['code3']
        ).execute()
        database.execute("UPDATE table_settings_languages SET code3b=? WHERE code3=?", (lang['code3b'], lang['code3']))


def language_from_alpha2(lang):
    result = TableSettingsLanguages.select(
        TableSettingsLanguages.name
    ).where(
        TableSettingsLanguages.code2 == lang
    ).first()
    return result.name
    result = database.execute("SELECT name FROM table_settings_languages WHERE code2=?", (lang,))
    return result[0]['name'] or None


def language_from_alpha3(lang):
    result = TableSettingsLanguages.select(
        TableSettingsLanguages.name
    ).where(
        (TableSettingsLanguages.code3 == lang) |
        (TableSettingsLanguages.code3b == lang)
    ).first()
    return result.name
    result = database.execute("SELECT name FROM table_settings_languages WHERE code3=? or code3b=?", (lang, lang))
    return result[0]['name'] or None


def alpha2_from_alpha3(lang):
    result = TableSettingsLanguages.select(
        TableSettingsLanguages.code2
    ).where(
        (TableSettingsLanguages.code3 == lang) |
        (TableSettingsLanguages.code3b == lang)
    ).first()
    return result.code2
    result = database.execute("SELECT code2 FROM table_settings_languages WHERE code3=? or code3b=?", (lang, lang))
    return result[0]['code2'] or None


def alpha2_from_language(lang):
    result = TableSettingsLanguages.select(
        TableSettingsLanguages.code2
    ).where(
        TableSettingsLanguages.name == lang
    ).first()
    return result.code2
    result = database.execute("SELECT code2 FROM table_settings_languages WHERE name=?", (lang,))
    return result[0]['code2'] or None


def alpha3_from_alpha2(lang):
    result = TableSettingsLanguages.select(
        TableSettingsLanguages.code3
    ).where(
        TableSettingsLanguages.code2 == lang
    ).first()
    return result.code3
    result = database.execute("SELECT code3 FROM table_settings_languages WHERE code2=?", (lang,))
    return result[0]['code3'] or None


def alpha3_from_language(lang):
    result = TableSettingsLanguages.select(
        TableSettingsLanguages.code3
    ).where(
        TableSettingsLanguages.name == lang
    ).first()
    return result.code3
    result = database.execute("SELECT code3 FROM table_settings_languages WHERE name=?", (lang,))
    return result[0]['code3'] or None


def get_language_set():
    languages = TableSettingsLanguages.select(
        TableSettingsLanguages.code3
    ).where(
        TableSettingsLanguages.enabled == 1
    )
    languages = database.execute("SELECT code3 FROM table_settings_languages WHERE enabled=1")

    language_set = set()

    for lang in languages:
        if lang.code3 == 'pob':
        if lang['code3'] == 'pob':
            language_set.add(Language('por', 'BR'))
        else:
            language_set.add(Language(lang.code3))
            language_set.add(Language(lang['code3']))

    return language_set
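The rewritten language helpers run a raw query and read the first row as a dict. A sketch of one of them against a throwaway in-memory table; sqlite3.Row plays the role of the worker's as_dict=True rows.

# Illustrative sketch, not part of this commit.
import sqlite3

conn = sqlite3.connect(':memory:')
conn.row_factory = sqlite3.Row  # rows behave like dicts, similar to as_dict=True
conn.execute("CREATE TABLE table_settings_languages "
             "(code3 TEXT PRIMARY KEY, code2 TEXT, name TEXT, enabled INTEGER, code3b TEXT)")
conn.execute("INSERT INTO table_settings_languages (code3, code2, name, enabled) "
             "VALUES ('fra', 'fr', 'French', 1)")

def alpha2_from_alpha3(lang):
    rows = conn.execute("SELECT code2 FROM table_settings_languages WHERE code3=? or code3b=?",
                        (lang, lang)).fetchall()
    return rows[0]['code2'] or None

print(alpha2_from_alpha3('fra'))  # fr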

@@ -13,14 +13,12 @@ from utils import get_radarr_version
from list_subtitles import store_subtitles_movie, list_missing_subtitles_movies, movies_full_scan_subtitles

from get_subtitle import movies_download_subtitles
from database import TableMovies, wal_cleaning
import six
from database import database, dict_converter


def update_all_movies():
    movies_full_scan_subtitles()
    logging.info('BAZARR All existing movie subtitles indexed from disk.')
    wal_cleaning()


def update_movies():

@@ -55,13 +53,9 @@ def update_movies():
        return
    else:
        # Get current movies in DB
        current_movies_db = TableMovies.select(
            TableMovies.tmdb_id,
            TableMovies.path,
            TableMovies.radarr_id
        )
        current_movies_db = database.execute("SELECT tmdbId, path, radarrId FROM table_movies")

        current_movies_db_list = [x.tmdb_id for x in current_movies_db]
        current_movies_db_list = [x['tmdbId'] for x in current_movies_db]

        current_movies_radarr = []
        movies_to_update = []

@@ -139,31 +133,31 @@ def update_movies():
                current_movies_radarr.append(six.text_type(movie['tmdbId']))

                if six.text_type(movie['tmdbId']) in current_movies_db_list:
                    movies_to_update.append({'radarr_id': movie["id"],
                    movies_to_update.append({'radarrId': movie["id"],
                                             'title': six.text_type(movie["title"]),
                                             'path': six.text_type(movie["path"] + separator + movie['movieFile']['relativePath']),
                                             'tmdb_id': six.text_type(movie["tmdbId"]),
                                             'tmdbId': six.text_type(movie["tmdbId"]),
                                             'poster': six.text_type(poster),
                                             'fanart': six.text_type(fanart),
                                             'audio_language': six.text_type(profile_id_to_language(movie['qualityProfileId'], audio_profiles)),
                                             'scene_name': sceneName,
                                             'sceneName': sceneName,
                                             'monitored': six.text_type(bool(movie['monitored'])),
                                             'year': six.text_type(movie['year']),
                                             'sort_title': six.text_type(movie['sortTitle']),
                                             'alternative_titles': six.text_type(alternativeTitles),
                                             'sortTitle': six.text_type(movie['sortTitle']),
                                             'alternativeTitles': six.text_type(alternativeTitles),
                                             'format': six.text_type(format),
                                             'resolution': six.text_type(resolution),
                                             'video_codec': six.text_type(videoCodec),
                                             'audio_codec': six.text_type(audioCodec),
                                             'overview': six.text_type(overview),
                                             'imdb_id': six.text_type(imdbId),
                                             'imdbId': six.text_type(imdbId),
                                             'movie_file_id': movie['movieFile']['id']})
                else:
                    if movie_default_enabled is True:
                        movies_to_add.append({'radarr_id': movie["id"],
                        movies_to_add.append({'radarrId': movie["id"],
                                              'title': movie["title"],
                                              'path': movie["path"] + separator + movie['movieFile']['relativePath'],
                                              'tmdb_id': movie["tmdbId"],
                                              'tmdbId': movie["tmdbId"],
                                              'languages': movie_default_language,
                                              'subtitles': '[]',
                                              'hearing_impaired': movie_default_hi,

@@ -171,37 +165,41 @@ def update_movies():
                                              'poster': poster,
                                              'fanart': fanart,
                                              'audio_language': profile_id_to_language(movie['qualityProfileId'], audio_profiles),
                                              'scene_name': sceneName,
                                              'sceneName': sceneName,
                                              'monitored': six.text_type(bool(movie['monitored'])),
                                              'sort_title': movie['sortTitle'],
                                              'sortTitle': movie['sortTitle'],
                                              'year': movie['year'],
                                              'alternative_titles': alternativeTitles,
                                              'alternativeTitles': alternativeTitles,
                                              'format': format,
                                              'resolution': resolution,
                                              'video_codec': videoCodec,
                                              'audio_codec': audioCodec,
                                              'imdb_id': imdbId,
                                              'imdbId': imdbId,
                                              'forced': movie_default_forced,
                                              'movie_file_id': movie['movieFile']['id']})
                    else:
                        movies_to_add.append({'radarr_id': movie["id"],
                        movies_to_add.append({'radarrId': movie["id"],
                                              'title': movie["title"],
                                              'path': movie["path"] + separator + movie['movieFile']['relativePath'],
                                              'tmdb_id': movie["tmdbId"],
                                              'tmdbId': movie["tmdbId"],
                                              'languages': None,
                                              'subtitles': '[]',
                                              'hearing_impaired': None,
                                              'overview': overview,
                                              'poster': poster,
                                              'fanart': fanart,
                                              'audio_language': profile_id_to_language(movie['qualityProfileId'], audio_profiles),
                                              'scene_name': sceneName,
                                              'sceneName': sceneName,
                                              'monitored': six.text_type(bool(movie['monitored'])),
                                              'sort_title': movie['sortTitle'],
                                              'sortTitle': movie['sortTitle'],
                                              'year': movie['year'],
                                              'alternative_titles': alternativeTitles,
                                              'alternativeTitles': alternativeTitles,
                                              'format': format,
                                              'resolution': resolution,
                                              'video_codec': videoCodec,
                                              'audio_codec': audioCodec,
                                              'imdb_id': imdbId,
                                              'imdbId': imdbId,
                                              'forced': None,
                                              'movie_file_id': movie['movieFile']['id']})
            else:
                logging.error(

@@ -210,27 +208,10 @@ def update_movies():

        # Update or insert movies in DB
        movies_in_db_list = []
        movies_in_db = TableMovies.select(
            TableMovies.radarr_id,
            TableMovies.title,
            TableMovies.path,
            TableMovies.tmdb_id,
            TableMovies.overview,
            TableMovies.poster,
            TableMovies.fanart,
            TableMovies.audio_language,
            TableMovies.scene_name,
            TableMovies.monitored,
            TableMovies.sort_title,
            TableMovies.year,
            TableMovies.alternative_titles,
            TableMovies.format,
            TableMovies.resolution,
            TableMovies.video_codec,
            TableMovies.audio_codec,
            TableMovies.imdb_id,
            TableMovies.movie_file_id
        ).dicts()
        movies_in_db = database.execute("SELECT radarrId, title, path, tmdbId, overview, poster, fanart, "
                                        "audio_language, sceneName, monitored, sortTitle, year, "
                                        "alternativeTitles, format, resolution, video_codec, audio_codec, imdbId,"
                                        "movie_file_id FROM table_movies")

        for item in movies_in_db:
            movies_in_db_list.append(item)

@@ -238,31 +219,30 @@ def update_movies():
        movies_to_update_list = [i for i in movies_to_update if i not in movies_in_db_list]

        for updated_movie in movies_to_update_list:
            TableMovies.update(
                updated_movie
            ).where(
                TableMovies.radarr_id == updated_movie['radarr_id']
            ).execute()
            altered_movies.append([updated_movie['tmdb_id'],
            query = dict_converter.convert(updated_movie)
            database.execute('''UPDATE table_movies SET ''' + query.keys_update + ''' WHERE radarrId = ?''',
                             query.values + (updated_movie['radarrId'],))
            altered_movies.append([updated_movie['tmdbId'],
                                   updated_movie['path'],
                                   updated_movie['radarr_id']])
                                   updated_movie['radarrId'],
                                   updated_movie['monitored']])

        # Insert new movies in DB
        for added_movie in movies_to_add:
            TableMovies.insert(
                added_movie
            ).on_conflict_ignore().execute()
            altered_movies.append([added_movie['tmdb_id'],
            query = dict_converter.convert(added_movie)
            database.execute(
                '''INSERT OR IGNORE INTO table_movies(''' + query.keys_insert + ''') VALUES(''' +
                query.question_marks + ''')''', query.values)
            altered_movies.append([added_movie['tmdbId'],
                                   added_movie['path'],
                                   added_movie['radarr_id']])
                                   added_movie['radarrId'],
                                   added_movie['monitored']])

        # Remove old movies from DB
        removed_movies = list(set(current_movies_db_list) - set(current_movies_radarr))

        for removed_movie in removed_movies:
            TableMovies.delete().where(
                TableMovies.tmdb_id == removed_movie
            ).execute()
            database.execute("DELETE FROM table_movies WHERE tmdbId=?", (removed_movie,))

        # Store subtitles for added or modified movies
        for i, altered_movie in enumerate(altered_movies, 1):

@@ -276,7 +256,11 @@ def update_movies():
        if len(altered_movies) <= 5:
            logging.debug("BAZARR No more than 5 movies were added during this sync then we'll search for subtitles.")
            for altered_movie in altered_movies:
                movies_download_subtitles(altered_movie[2])
                if settings.radarr.getboolean('only_monitored'):
                    if altered_movie[3] == 'True':
                        movies_download_subtitles(altered_movie[2])
                else:
                    movies_download_subtitles(altered_movie[2])
        else:
            logging.debug("BAZARR More than 5 movies were added during this sync then we wont search for subtitles.")
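The last hunk gates the post-sync subtitle search on the only_monitored setting, which is why monitored is now carried as the fourth element of each altered_movies entry. A sketch of that gate with stand-in data.

# Illustrative sketch, not part of this commit: only_monitored and altered_movies
# stand in for the real settings value and sync output.
only_monitored = True
altered_movies = [
    # [tmdbId, path, radarrId, monitored]
    ['603', '/movies/The Matrix/The Matrix.mkv', 10, 'True'],
    ['604', '/movies/Reloaded/Reloaded.mkv', 11, 'False'],
]

def movies_download_subtitles(radarr_id):
    print('searching subtitles for radarrId', radarr_id)

for altered_movie in altered_movies:
    if only_monitored:
        if altered_movie[3] == 'True':
            movies_download_subtitles(altered_movie[2])
    else:
        movies_download_subtitles(altered_movie[2])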

@@ -12,7 +12,7 @@ import datetime
from get_args import args
from config import settings, url_sonarr
from list_subtitles import list_missing_subtitles
from database import TableShows
from database import database, dict_converter
from utils import get_sonarr_version
import six

@@ -50,11 +50,9 @@ def update_series():
        return
    else:
        # Get current shows in DB
        current_shows_db = TableShows.select(
            TableShows.tvdb_id
        )
        current_shows_db = database.execute("SELECT tvdbId FROM table_shows")

        current_shows_db_list = [x.tvdb_id for x in current_shows_db]
        current_shows_db_list = [x['tvdbId'] for x in current_shows_db]
        current_shows_sonarr = []
        series_to_update = []
        series_to_add = []

@@ -88,59 +86,48 @@ def update_series():
            if show['tvdbId'] in current_shows_db_list:
                series_to_update.append({'title': six.text_type(show["title"]),
                                         'path': six.text_type(show["path"]),
                                         'tvdb_id': int(show["tvdbId"]),
                                         'sonarr_series_id': int(show["id"]),
                                         'tvdbId': int(show["tvdbId"]),
                                         'sonarrSeriesId': int(show["id"]),
                                         'overview': six.text_type(overview),
                                         'poster': six.text_type(poster),
                                         'fanart': six.text_type(fanart),
                                         'audio_language': six.text_type(profile_id_to_language((show['qualityProfileId'] if get_sonarr_version().startswith('2') else show['languageProfileId']), audio_profiles)),
                                         'sort_title': six.text_type(show['sortTitle']),
                                         'sortTitle': six.text_type(show['sortTitle']),
                                         'year': six.text_type(show['year']),
                                         'alternate_titles': six.text_type(alternateTitles)})
                                         'alternateTitles': six.text_type(alternateTitles)})
            else:
                if serie_default_enabled is True:
                    series_to_add.append({'title': show["title"],
                                          'path': show["path"],
                                          'tvdb_id': show["tvdbId"],
                                          'tvdbId': show["tvdbId"],
                                          'languages': serie_default_language,
                                          'hearing_impaired': serie_default_hi,
                                          'sonarr_series_id': show["id"],
                                          'sonarrSeriesId': show["id"],
                                          'overview': overview,
                                          'poster': poster,
                                          'fanart': fanart,
                                          'audio_language': profile_id_to_language((show['qualityProfileId'] if sonarr_version.startswith('2') else show['languageProfileId']), audio_profiles),
                                          'sort_title': show['sortTitle'],
                                          'sortTitle': show['sortTitle'],
                                          'year': show['year'],
                                          'alternate_titles': alternateTitles,
                                          'alternateTitles': alternateTitles,
                                          'forced': serie_default_forced})
                else:
                    series_to_add.append({'title': show["title"],
                                          'path': show["path"],
                                          'tvdb_id': show["tvdbId"],
                                          'sonarr_series_id': show["id"],
                                          'tvdbId': show["tvdbId"],
                                          'sonarrSeriesId': show["id"],
                                          'overview': overview,
                                          'poster': poster,
                                          'fanart': fanart,
                                          'audio_language': profile_id_to_language((show['qualityProfileId'] if sonarr_version.startswith('2') else show['languageProfileId']), audio_profiles),
                                          'sort_title': show['sortTitle'],
                                          'sortTitle': show['sortTitle'],
                                          'year': show['year'],
                                          'alternate_titles': alternateTitles})
                                          'alternateTitles': alternateTitles})

        # Update existing series in DB
        series_in_db_list = []
        series_in_db = TableShows.select(
            TableShows.title,
            TableShows.path,
            TableShows.tvdb_id,
            TableShows.sonarr_series_id,
            TableShows.overview,
            TableShows.poster,
            TableShows.fanart,
            TableShows.audio_language,
            TableShows.sort_title,
            TableShows.year,
            TableShows.alternate_titles
        ).dicts()
        series_in_db = database.execute("SELECT title, path, tvdbId, sonarrSeriesId, overview, poster, fanart, "
                                        "audio_language, sortTitle, year, alternateTitles FROM table_shows")

        for item in series_in_db:
            series_in_db_list.append(item)

@@ -148,26 +135,23 @@ def update_series():
        series_to_update_list = [i for i in series_to_update if i not in series_in_db_list]

        for updated_series in series_to_update_list:
            TableShows.update(
                updated_series
            ).where(
                TableShows.sonarr_series_id == updated_series['sonarr_series_id']
            ).execute()
            query = dict_converter.convert(updated_series)
            database.execute('''UPDATE table_shows SET ''' + query.keys_update + ''' WHERE sonarrSeriesId = ?''',
                             query.values + (updated_series['sonarrSeriesId'],))

        # Insert new series in DB
        for added_series in series_to_add:
            TableShows.insert(
                added_series
            ).on_conflict_ignore().execute()
            list_missing_subtitles(no=added_series['sonarr_series_id'])
            query = dict_converter.convert(added_series)
            database.execute(
                '''INSERT OR IGNORE INTO table_shows(''' + query.keys_insert + ''') VALUES(''' +
                query.question_marks + ''')''', query.values)
            list_missing_subtitles(no=added_series['sonarrSeriesId'])

        # Remove old series from DB
        removed_series = list(set(current_shows_db_list) - set(current_shows_sonarr))

        for series in removed_series:
            TableShows.delete().where(
                TableShows.tvdb_id == series
            ).execute()
            database.execute("DELETE FROM table_shows WHERE tvdbId=?",(series,))

        logging.debug('BAZARR All series synced from Sonarr into database.')
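The series sync can insert blindly because INSERT OR IGNORE skips rows that would violate the tvdbId primary key, so re-running the sync does not duplicate shows. A small sketch of that behaviour.

# Illustrative sketch, not part of this commit.
import sqlite3

conn = sqlite3.connect(':memory:')
conn.execute("CREATE TABLE table_shows (tvdbId INTEGER PRIMARY KEY, title TEXT, sonarrSeriesId INTEGER)")

for _ in range(2):  # simulate two consecutive syncs
    conn.execute("INSERT OR IGNORE INTO table_shows(tvdbId, title, sonarrSeriesId) VALUES(?, ?, ?)",
                 (121361, 'Game of Thrones', 5))

print(conn.execute("SELECT COUNT(*) FROM table_shows").fetchone()[0])  # 1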

@@ -34,8 +34,7 @@ from get_providers import get_providers, get_providers_auth, provider_throttle,
from get_args import args
from queueconfig import notifications
from pyprobe.pyprobe import VideoFileParser
from database import database, TableShows, TableEpisodes, TableMovies, TableHistory, TableHistoryMovie
from peewee import fn, JOIN
from database import database, dict_mapper

from analytics import track_event
import six

@@ -554,51 +553,37 @@ def manual_upload_subtitle(path, language, forced, title, scene_name, media_type


def series_download_subtitles(no):
    episodes_details_clause = [
        (TableEpisodes.sonarr_series_id == no),
        (TableEpisodes.missing_subtitles != '[]')
    ]
    if settings.sonarr.getboolean('only_monitored'):
        episodes_details_clause.append(
            (TableEpisodes.monitored == 'True')
        )
        episodes_details_clause = " AND monitored='True'"
    else:
        episodes_details_clause = ''

    episodes_details = TableEpisodes.select(
        TableEpisodes.path,
        TableEpisodes.missing_subtitles,
        TableEpisodes.sonarr_episode_id,
        TableEpisodes.scene_name
    ).where(
        reduce(operator.and_, episodes_details_clause)
    )
    episodes_details = database.execute("SELECT path, missing_subtitles, sonarrEpisodeId, scene_name "
                                        "FROM table_episodes WHERE sonarrSeriesId=? and missing_subtitles!='[]'" +
                                        episodes_details_clause, (no,))

    series_details = TableShows.select(
        TableShows.hearing_impaired,
        TableShows.title,
        TableShows.forced
    ).where(
        TableShows.sonarr_series_id == no
    ).first()
    series_details = database.execute("SELECT hearing_impaired, title, forced FROM table_shows WHERE sonarrSeriesId=?",
                                      (no,), only_one=True)

    providers_list = get_providers()
    providers_auth = get_providers_auth()

    count_episodes_details = episodes_details.count()
    count_episodes_details = len(episodes_details)

    for i, episode in enumerate(episodes_details, 1):
        if providers_list:
            for language in ast.literal_eval(episode.missing_subtitles):
            for language in ast.literal_eval(episode['missing_subtitles']):
                if language is not None:
                    notifications.write(msg='Searching for Series Subtitles...', queue='get_subtitle', item=i,
                                        length=count_episodes_details)
                    result = download_subtitle(path_replace(episode.path),
                                               str(alpha3_from_alpha2(language.split(':'))),
                                               series_details.hearing_impaired,
                    result = download_subtitle(path_replace(episode['path']),
                                               str(alpha3_from_alpha2(language.split(':')[0])),
                                               series_details['hearing_impaired'],
                                               "True" if len(language.split(':')) > 1 else "False",
                                               providers_list,
                                               providers_auth,
                                               str(episode.scene_name),
                                               series_details.title,
                                               str(episode['scene_name']),
                                               series_details['title'],
                                               'series')
                    if result is not None:
                        message = result[0]

@@ -607,9 +592,9 @@ def series_download_subtitles(no):
                        language_code = result[2] + ":forced" if forced else result[2]
                        provider = result[3]
                        score = result[4]
                        store_subtitles(episode.path, path_replace(episode.path))
                        history_log(1, no, episode.sonarr_episode_id, message, path, language_code, provider, score)
                        send_notifications(no, episode.sonarr_episode_id, message)
                        store_subtitles(episode['path'], path_replace(episode['path']))
                        history_log(1, no, episode['sonarrEpisodeId'], message, path, language_code, provider, score)
                        send_notifications(no, episode['sonarrEpisodeId'], message)
        else:
            notifications.write(msg='BAZARR All providers are throttled', queue='get_subtitle', duration='long')
            logging.info("BAZARR All providers are throttled")

@@ -621,47 +606,36 @@ def series_download_subtitles(no):


def episode_download_subtitles(no):
    episodes_details_clause = [
        (TableEpisodes.sonarr_episode_id == no)
    ]
    if settings.sonarr.getboolean('only_monitored'):
        episodes_details_clause.append(
            (TableEpisodes.monitored == 'True')
        )
        episodes_details_clause = " AND monitored='True'"
    else:
        episodes_details_clause = ''

    episodes_details = TableEpisodes.select(
        TableEpisodes.path,
        TableEpisodes.missing_subtitles,
        TableEpisodes.sonarr_episode_id,
        TableEpisodes.scene_name,
        TableShows.hearing_impaired,
        TableShows.title,
        TableShows.sonarr_series_id,
        TableShows.forced
    ).join_from(
        TableEpisodes, TableShows, JOIN.LEFT_OUTER
    ).where(
        reduce(operator.and_, episodes_details_clause)
    )
    episodes_details = database.execute("SELECT table_episodes.path, table_episodes.missing_subtitles, "
                                        "table_episodes.sonarrEpisodeId, table_episodes.scene_name, "
                                        "table_shows.hearing_impaired, table_shows.title, table_shows.sonarrSeriesId, "
                                        "table_shows.forced FROM table_episodes LEFT JOIN table_shows on "
                                        "table_episodes.sonarrSeriesId = table_shows.sonarrSeriesId "
                                        "WHERE sonarrEpisodeId=?" + episodes_details_clause, (no,))

    providers_list = get_providers()
    providers_auth = get_providers_auth()

    for episode in episodes_details:
        if providers_list:
            for language in ast.literal_eval(episode.missing_subtitles):
            for language in ast.literal_eval(episode['missing_subtitles']):
                if language is not None:
                    notifications.write(msg='Searching for ' + str(
                        language_from_alpha2(language)) + ' Subtitles for this episode: ' + path_replace(episode.path),
                        queue='get_subtitle')
                    result = download_subtitle(path_replace(episode.path),
                        language_from_alpha2(language)) + ' Subtitles for this episode: ' +
                        path_replace(episode['path']), queue='get_subtitle')
                    result = download_subtitle(path_replace(episode['path']),
                                               str(alpha3_from_alpha2(language.split(':')[0])),
                                               episode.hearing_impaired,
                                               episode['hearing_impaired'],
                                               "True" if len(language.split(':')) > 1 else "False",
                                               providers_list,
                                               providers_auth,
                                               str(episode.scene_name),
                                               episode.title,
                                               str(episode['scene_name']),
                                               episode['title'],
                                               'series')
                    if result is not None:
                        message = result[0]

@@ -670,9 +644,9 @@ def episode_download_subtitles(no):
                        language_code = result[2] + ":forced" if forced else result[2]
                        provider = result[3]
                        score = result[4]
                        store_subtitles(episode.path, path_replace(episode.path))
                        history_log(1, episode.sonarr_series_id, episode.sonarr_episode_id, message, path, language_code, provider, score)
                        send_notifications(episode.sonarr_series_id, episode.sonarr_episode_id, message)
                        store_subtitles(episode['path'], path_replace(episode['path']))
                        history_log(1, episode['sonarrSeriesId'], episode['sonarrEpisodeId'], message, path, language_code, provider, score)
                        send_notifications(episode['sonarrSeriesId'], episode['sonarrEpisodeId'], message)
        else:
            notifications.write(msg='BAZARR All providers are throttled', queue='get_subtitle', duration='long')
            logging.info("BAZARR All providers are throttled")

@@ -680,36 +654,35 @@ def episode_download_subtitles(no):


def movies_download_subtitles(no):
    movie = TableMovies.select(
        TableMovies.path,
        TableMovies.missing_subtitles,
        TableMovies.radarr_id,
        TableMovies.scene_name,
        TableMovies.hearing_impaired,
        TableMovies.title,
        TableMovies.forced
    ).where(
        TableMovies.radarr_id == no
    ).first()
    if settings.radarr.getboolean('only_monitored'):
        movie_details_clause = " AND monitored='True'"
    else:
        movie_details_clause = ''

    movie = database.execute("SELECT path, missing_subtitles, radarrId, sceneName, hearing_impaired, title, forced "
                             "FROM table_movies WHERE radarrId=?" + movie_details_clause, (no,), only_one=True)

    providers_list = get_providers()
    providers_auth = get_providers_auth()

    if ast.literal_eval(movie['missing_subtitles']):
        count_movie = len(ast.literal_eval(movie['missing_subtitles']))
    else:
        count_movie = 0

    count_movie = len(ast.literal_eval(movie.missing_subtitles))

    for i, language in enumerate(ast.literal_eval(movie.missing_subtitles), 1):
    for i, language in enumerate(ast.literal_eval(movie['missing_subtitles']), 1):
        if providers_list:
            if language is not None:
                notifications.write(msg='Searching for Movie Subtitles', queue='get_subtitle', item=i,
                                    length=count_movie)
                result = download_subtitle(path_replace_movie(movie.path),
                result = download_subtitle(path_replace_movie(movie['path']),
                                           str(alpha3_from_alpha2(language.split(':')[0])),
                                           movie.hearing_impaired,
                                           movie['hearing_impaired'],
                                           "True" if len(language.split(':')) > 1 else "False",
                                           providers_list,
                                           providers_auth,
                                           str(movie.scene_name),
                                           movie.title,
                                           str(movie['sceneName']),
                                           movie['title'],
                                           'movie')
                if result is not None:
                    message = result[0]

@@ -718,7 +691,7 @@ def movies_download_subtitles(no):
                    language_code = result[2] + ":forced" if forced else result[2]
                    provider = result[3]
                    score = result[4]
                    store_subtitles_movie(movie.path, path_replace_movie(movie.path))
                    store_subtitles_movie(movie['path'], path_replace_movie(movie['path']))
                    history_log_movie(1, no, message, path, language_code, provider, score)
                    send_notifications_movie(no, message)
        else:

@@ -732,32 +705,23 @@ def movies_download_subtitles(no):


def wanted_download_subtitles(path, l, count_episodes):

    episodes_details = TableEpisodes.select(
        TableEpisodes.path,
        TableEpisodes.missing_subtitles,
        TableEpisodes.sonarr_episode_id,
        TableEpisodes.sonarr_series_id,
        TableShows.hearing_impaired,
        TableEpisodes.scene_name,
        TableEpisodes.failed_attempts,
        TableShows.title,
        TableShows.forced
    ).join_from(
        TableEpisodes, TableShows, JOIN.LEFT_OUTER
    ).where(
        (TableEpisodes.path == path_replace_reverse(path)) &
        (TableEpisodes.missing_subtitles != '[]')
    ).objects()
    episodes_details = database.execute("SELECT table_episodes.path, table_episodes.missing_subtitles, "
                                        "table_episodes.sonarrEpisodeId, table_episodes.sonarrSeriesId, "
                                        "table_shows.hearing_impaired, table_episodes.scene_name,"
                                        "table_episodes.failedAttempts, table_shows.title, table_shows.forced "
                                        "FROM table_episodes LEFT JOIN table_shows on "
                                        "table_episodes.sonarrSeriesId = table_shows.sonarrSeriesId "
                                        "WHERE table_episodes.path=? and table_episodes.missing_subtitles!='[]'",
                                        (path_replace_reverse(path),))

    providers_list = get_providers()
    providers_auth = get_providers_auth()

    for episode in episodes_details:
        attempt = episode.failed_attempts
        attempt = episode['failedAttempts']
        if type(attempt) == six.text_type:
            attempt = ast.literal_eval(attempt)
        for language in ast.literal_eval(episode.missing_subtitles):
        for language in ast.literal_eval(episode['missing_subtitles']):
            if attempt is None:
                attempt = []
                attempt.append([language, time.time()])

@@ -766,27 +730,22 @@ def wanted_download_subtitles(path, l, count_episodes):
                if language not in att:
                    attempt.append([language, time.time()])

            TableEpisodes.update(
                {
                    TableEpisodes.failed_attempts: six.text_type(attempt)
                }
            ).where(
                TableEpisodes.sonarr_episode_id == episode.sonarr_episode_id
            ).execute()
            database.execute("UPDATE table_episodes SET failedAttempts=? WHERE sonarrEpisodeId=?",
                             (six.text_type(attempt), episode['sonarrEpisodeId']))

            for i in range(len(attempt)):
                if attempt[i][0] == language:
                    if search_active(attempt[i][1]):
                        notifications.write(msg='Searching for Series Subtitles...', queue='get_subtitle', item=l,
                                            length=count_episodes)
                        result = download_subtitle(path_replace(episode.path),
                        result = download_subtitle(path_replace(episode['path']),
                                                   str(alpha3_from_alpha2(language.split(':')[0])),
                                                   episode.hearing_impaired,
                                                   episode['hearing_impaired'],
                                                   "True" if len(language.split(':')) > 1 else "False",
                                                   providers_list,
                                                   providers_auth,
                                                   str(episode.scene_name),
                                                   episode.title,
                                                   str(episode['scene_name']),
                                                   episode['title'],
                                                   'series')
                        if result is not None:
                            message = result[0]

@@ -795,37 +754,27 @@ def wanted_download_subtitles(path, l, count_episodes):
                            language_code = result[2] + ":forced" if forced else result[2]
                            provider = result[3]
                            score = result[4]
                            store_subtitles(episode.path, path_replace(episode.path))
                            history_log(1, episode.sonarr_series_id.sonarr_series_id, episode.sonarr_episode_id, message, path, language_code, provider, score)
                            send_notifications(episode.sonarr_series_id.sonarr_series_id, episode.sonarr_episode_id, message)
                            store_subtitles(episode['path'], path_replace(episode['path']))
                            history_log(1, episode['sonarrSeriesId'], episode['sonarrEpisodeId'], message, path, language_code, provider, score)
                            send_notifications(episode['sonarrSeriesId'], episode['sonarrEpisodeId'], message)
                    else:
                        logging.debug(
                            'BAZARR Search is not active for episode ' + episode.path + ' Language: ' + attempt[i][0])
                            'BAZARR Search is not active for episode ' + episode['path'] + ' Language: ' + attempt[i][0])


def wanted_download_subtitles_movie(path, l, count_movies):
    movies_details = TableMovies.select(
        TableMovies.path,
        TableMovies.missing_subtitles,
        TableMovies.radarr_id,
        TableMovies.hearing_impaired,
        TableMovies.scene_name,
        TableMovies.failed_attempts,
        TableMovies.title,
        TableMovies.forced
    ).where(
        (TableMovies.path == path_replace_reverse_movie(path)) &
        (TableMovies.missing_subtitles != '[]')
    )
    movies_details = database.execute("SELECT path, missing_subtitles, radarrId, hearing_impaired, sceneName, "
                                      "failedAttempts, title, forced FROM table_movies WHERE path = ? "
                                      "AND missing_subtitles != '[]'", (path_replace_reverse_movie(path),))

    providers_list = get_providers()
    providers_auth = get_providers_auth()

    for movie in movies_details:
        attempt = movie.failed_attempts
        attempt = movie['failedAttempts']
        if type(attempt) == six.text_type:
            attempt = ast.literal_eval(attempt)
        for language in ast.literal_eval(movie.missing_subtitles):
        for language in ast.literal_eval(movie['missing_subtitles']):
            if attempt is None:
                attempt = []
                attempt.append([language, time.time()])

@@ -834,27 +783,22 @@ def wanted_download_subtitles_movie(path, l, count_movies):
                if language not in att:
                    attempt.append([language, time.time()])

            TableMovies.update(
                {
                    TableMovies.failed_attempts: six.text_type(attempt)
                }
            ).where(
                TableMovies.radarr_id == movie.radarr_id
            ).execute()
            database.execute("UPDATE table_movies SET failedAttempts=? WHERE radarrId=?",
                             (six.text_type(attempt), movie['radarrId']))

            for i in range(len(attempt)):
                if attempt[i][0] == language:
                    if search_active(attempt[i][1]) is True:
                        notifications.write(msg='Searching for Movie Subtitles...', queue='get_subtitle', item=l,
                                            length=count_movies)
                        result = download_subtitle(path_replace_movie(movie.path),
                        result = download_subtitle(path_replace_movie(movie['path']),
                                                   str(alpha3_from_alpha2(language.split(':')[0])),
                                                   movie.hearing_impaired,
                                                   movie['hearing_impaired'],
                                                   "True" if len(language.split(':')) > 1 else "False",
                                                   providers_list,
                                                   providers_auth,
                                                   str(movie.scene_name),
                                                   movie.title,
                                                   str(movie['sceneName']),
                                                   movie['title'],
                                                   'movie')
                        if result is not None:
                            message = result[0]

@@ -863,59 +807,52 @@ def wanted_download_subtitles_movie(path, l, count_movies):
                            language_code = result[2] + ":forced" if forced else result[2]
                            provider = result[3]
                            score = result[4]
                            store_subtitles_movie(movie.path, path_replace_movie(movie.path))
                            history_log_movie(1, movie.radarr_id, message, path, language_code, provider, score)
                            send_notifications_movie(movie.radarr_id, message)
                            store_subtitles_movie(movie['path'], path_replace_movie(movie['path']))
                            history_log_movie(1, movie['radarrId'], message, path, language_code, provider, score)
                            send_notifications_movie(movie['radarrId'], message)
                    else:
                        logging.info(
                            'BAZARR Search is not active for this Movie ' + movie.path + ' Language: ' + attempt[i][0])
                            'BAZARR Search is not active for this Movie ' + movie['path'] + ' Language: ' + attempt[i][0])
|
||||
|
||||
|
||||
def wanted_search_missing_subtitles():
|
||||
if settings.general.getboolean('use_sonarr'):
|
||||
episodes_clause = [
|
||||
(TableEpisodes.missing_subtitles != '[]')
|
||||
]
|
||||
if settings.sonarr.getboolean('only_monitored'):
|
||||
episodes_clause.append(
|
||||
(TableEpisodes.monitored == 'True')
|
||||
)
|
||||
monitored_only_query_string_sonarr = ' AND monitored = "True"'
|
||||
else:
|
||||
monitored_only_query_string_sonarr = ""
|
||||
|
||||
episodes = TableEpisodes.select(
|
||||
fn.path_substitution(TableEpisodes.path).alias('path')
|
||||
).where(
|
||||
reduce(operator.and_, episodes_clause)
|
||||
)
|
||||
episodes = database.execute("SELECT path FROM table_episodes WHERE missing_subtitles != '[]'" +
|
||||
monitored_only_query_string_sonarr)
|
||||
# path_replace
|
||||
dict_mapper.path_replace(episodes)
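`dict_mapper.path_replace()` and `path_replace_movie()` are assumed here to rewrite the `path` key of each returned row in place using the existing path helpers; the actual implementation lives elsewhere in bazarr and may differ. A minimal sketch under that assumption:

from helper import path_replace, path_replace_movie


class _DictMapper(object):
    # Assumed behaviour: map Sonarr/Radarr paths to local paths directly inside the row dicts.
    @staticmethod
    def path_replace(rows):
        for row in rows:
            row['path'] = path_replace(row['path'])

    @staticmethod
    def path_replace_movie(rows):
        for row in rows:
            row['path'] = path_replace_movie(row['path'])


dict_mapper = _DictMapper()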
|
||||
|
||||
count_episodes = episodes.count()
|
||||
count_episodes = len(episodes)
|
||||
for i, episode in enumerate(episodes, 1):
|
||||
providers = get_providers()
|
||||
if providers:
|
||||
wanted_download_subtitles(episode.path, i, count_episodes)
|
||||
wanted_download_subtitles(episode['path'], i, count_episodes)
|
||||
else:
|
||||
notifications.write(msg='BAZARR All providers are throttled', queue='get_subtitle', duration='long')
|
||||
logging.info("BAZARR All providers are throttled")
|
||||
return
|
||||
|
||||
if settings.general.getboolean('use_radarr'):
|
||||
movies_clause = [
|
||||
(TableMovies.missing_subtitles != '[]')
|
||||
]
|
||||
if settings.radarr.getboolean('only_monitored'):
|
||||
movies_clause.append(
|
||||
(TableMovies.monitored == 'True')
|
||||
)
|
||||
movies = TableMovies.select(
|
||||
fn.path_substitution_movie(TableMovies.path).alias('path')
|
||||
).where(
|
||||
reduce(operator.and_, movies_clause)
|
||||
)
|
||||
monitored_only_query_string_radarr = ' AND monitored = "True"'
|
||||
else:
|
||||
monitored_only_query_string_radarr = ""
|
||||
|
||||
count_movies = movies.count()
|
||||
movies = database.execute("SELECT path FROM table_movies WHERE missing_subtitles != '[]'" +
|
||||
monitored_only_query_string_radarr)
|
||||
# path_replace
|
||||
dict_mapper.path_replace_movie(movies)
|
||||
|
||||
count_movies = len(movies)
|
||||
for i, movie in enumerate(movies, 1):
|
||||
providers = get_providers()
|
||||
if providers:
|
||||
wanted_download_subtitles_movie(movie.path, i, count_movies)
|
||||
wanted_download_subtitles_movie(movie['path'], i, count_movies)
|
||||
else:
|
||||
notifications.write(msg='BAZARR All providers are throttled', queue='get_subtitle', duration='long')
|
||||
logging.info("BAZARR All providers are throttled")
|
||||
|
@ -948,70 +885,50 @@ def search_active(timestamp):
|
|||
|
||||
def refine_from_db(path, video):
|
||||
if isinstance(video, Episode):
|
||||
data = TableEpisodes.select(
|
||||
TableShows.title.alias('seriesTitle'),
|
||||
TableEpisodes.season,
|
||||
TableEpisodes.episode,
|
||||
TableEpisodes.title.alias('episodeTitle'),
|
||||
TableShows.year,
|
||||
TableShows.tvdb_id,
|
||||
TableShows.alternate_titles,
|
||||
TableEpisodes.format,
|
||||
TableEpisodes.resolution,
|
||||
TableEpisodes.video_codec,
|
||||
TableEpisodes.audio_codec,
|
||||
TableEpisodes.path
|
||||
).join_from(
|
||||
TableEpisodes, TableShows, JOIN.LEFT_OUTER
|
||||
).where(
|
||||
TableEpisodes.path == path_replace_reverse(path)
|
||||
).objects().first()
|
||||
data = database.execute("SELECT table_shows.title as seriesTitle, table_episodes.season, table_episodes.episode, "
|
||||
"table_episodes.title as episodeTitle, table_shows.year, table_shows.tvdbId, "
|
||||
"table_shows.alternateTitles, table_episodes.format, table_episodes.resolution, "
|
||||
"table_episodes.video_codec, table_episodes.audio_codec, table_episodes.path "
|
||||
"FROM table_episodes INNER JOIN table_shows on "
|
||||
"table_shows.sonarrSeriesId = table_episodes.sonarrSeriesId "
|
||||
"WHERE table_episodes.path = ?", (unicode(path_replace_reverse(path)),), only_one=True)
|
||||
|
||||
if data:
|
||||
video.series, year, country = series_re.match(data.seriesTitle).groups()
|
||||
video.season = int(data.season)
|
||||
video.episode = int(data.episode)
|
||||
video.title = data.episodeTitle
|
||||
if data.year:
|
||||
if int(data.year) > 0: video.year = int(data.year)
|
||||
video.series_tvdb_id = int(data.tvdb_id)
|
||||
video.alternative_series = ast.literal_eval(data.alternate_titles)
|
||||
video.series, year, country = series_re.match(data['seriesTitle']).groups()
|
||||
video.season = int(data['season'])
|
||||
video.episode = int(data['episode'])
|
||||
video.title = data['episodeTitle']
|
||||
if data['year']:
|
||||
if int(data['year']) > 0: video.year = int(data['year'])
|
||||
video.series_tvdb_id = int(data['tvdbId'])
|
||||
video.alternative_series = ast.literal_eval(data['alternateTitles'])
|
||||
if not video.format:
|
||||
video.format = str(data.format)
|
||||
video.format = str(data['format'])
|
||||
if not video.resolution:
|
||||
video.resolution = str(data.resolution)
|
||||
video.resolution = str(data['resolution'])
|
||||
if not video.video_codec:
|
||||
if data.video_codec: video.video_codec = data.video_codec
|
||||
if data['video_codec']: video.video_codec = data['video_codec']
|
||||
if not video.audio_codec:
|
||||
if data.audio_codec: video.audio_codec = data.audio_codec
|
||||
if data['audio_codec']: video.audio_codec = data['audio_codec']
|
||||
elif isinstance(video, Movie):
|
||||
data = TableMovies.select(
|
||||
TableMovies.title,
|
||||
TableMovies.year,
|
||||
TableMovies.alternative_titles,
|
||||
TableMovies.format,
|
||||
TableMovies.resolution,
|
||||
TableMovies.video_codec,
|
||||
TableMovies.audio_codec,
|
||||
TableMovies.imdb_id
|
||||
).where(
|
||||
TableMovies.path == six.text_type(path_replace_reverse_movie(path))
|
||||
).first()
|
||||
data = database.execute("SELECT title, year, alternativeTitles, format, resolution, video_codec, audio_codec, "
|
||||
"imdbId FROM table_movies WHERE path = ?",
|
||||
(text_type(path_replace_reverse_movie(path)),), only_one=True)
|
||||
|
||||
if data:
|
||||
video.title = re.sub(r'(\(\d\d\d\d\))', '', data.title)
|
||||
if data.year:
|
||||
if int(data.year) > 0: video.year = int(data.year)
|
||||
if data.imdb_id: video.imdb_id = data.imdb_id
|
||||
video.alternative_titles = ast.literal_eval(data.alternative_titles)
|
||||
video.title = re.sub(r'(\(\d\d\d\d\))', '', data['title'])
|
||||
if data['year']:
|
||||
if int(data['year']) > 0: video.year = int(data['year'])
|
||||
if data['imdbId']: video.imdb_id = data['imdbId']
|
||||
video.alternative_titles = ast.literal_eval(data['alternativeTitles'])
|
||||
if not video.format:
|
||||
if data.format: video.format = data.format
|
||||
if data['format']: video.format = data['format']
|
||||
if not video.resolution:
|
||||
if data.resolution: video.resolution = data.resolution
|
||||
if data['resolution']: video.resolution = data['resolution']
|
||||
if not video.video_codec:
|
||||
if data.video_codec: video.video_codec = data.video_codec
|
||||
if data['video_codec']: video.video_codec = data['video_codec']
|
||||
if not video.audio_codec:
|
||||
if data.audio_codec: video.audio_codec = data.audio_codec
|
||||
if data['audio_codec']: video.audio_codec = data['audio_codec']
|
||||
|
||||
return video
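A hedged usage sketch for refine_from_db(): it backfills metadata that could not be guessed from the filename with what Sonarr/Radarr already stored in the local database. The path below is purely illustrative.

from subliminal import scan_video

episode_path = '/tv/Some Show/Season 01/Some.Show.S01E01.720p.WEB.mkv'  # hypothetical
video = scan_video(episode_path)
video = refine_from_db(episode_path, video)
# When a matching row exists, video.series, video.season, video.episode,
# video.series_tvdb_id, etc. are now populated from table_episodes/table_shows.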
|
||||
|
||||
|
@ -1062,58 +979,38 @@ def upgrade_subtitles():
|
|||
minimum_timestamp = ((datetime.now() - timedelta(days=int(days_to_upgrade_subs))) -
|
||||
datetime(1970, 1, 1)).total_seconds()
|
||||
|
||||
if settings.sonarr.getboolean('only_monitored'):
|
||||
series_monitored_only_query_string = ' AND table_episodes.monitored = "True"'
|
||||
else:
|
||||
series_monitored_only_query_string = ""
|
||||
|
||||
if settings.radarr.getboolean('only_monitored'):
|
||||
movies_monitored_only_query_string = ' AND table_movies.monitored = "True"'
|
||||
else:
|
||||
movies_monitored_only_query_string = ""
|
||||
|
||||
if settings.general.getboolean('upgrade_manual'):
|
||||
query_actions = [1, 2, 3]
|
||||
else:
|
||||
query_actions = [1, 3]
|
||||
|
||||
episodes_details_clause = [
|
||||
(TableHistory.action.in_(query_actions)) &
|
||||
(TableHistory.score.is_null(False))
|
||||
]
|
||||
|
||||
if settings.sonarr.getboolean('only_monitored'):
|
||||
episodes_details_clause.append(
|
||||
(TableEpisodes.monitored == 'True')
|
||||
)
|
||||
|
||||
movies_details_clause = [
|
||||
(TableHistoryMovie.action.in_(query_actions)) &
|
||||
(TableHistoryMovie.score.is_null(False))
|
||||
]
|
||||
|
||||
if settings.radarr.getboolean('only_monitored'):
|
||||
movies_details_clause.append(
|
||||
(TableMovies.monitored == 'True')
|
||||
)
|
||||
|
||||
|
||||
if settings.general.getboolean('use_sonarr'):
|
||||
upgradable_episodes = TableHistory.select(
|
||||
TableHistory.video_path,
|
||||
TableHistory.language,
|
||||
TableHistory.score,
|
||||
TableShows.hearing_impaired,
|
||||
TableEpisodes.scene_name,
|
||||
TableEpisodes.title,
|
||||
TableEpisodes.sonarr_series_id,
|
||||
TableEpisodes.sonarr_episode_id,
|
||||
fn.MAX(TableHistory.timestamp).alias('timestamp'),
|
||||
TableShows.languages,
|
||||
TableShows.forced
|
||||
).join_from(
|
||||
TableHistory, TableShows, JOIN.LEFT_OUTER
|
||||
).join_from(
|
||||
TableHistory, TableEpisodes, JOIN.LEFT_OUTER
|
||||
).where(
|
||||
reduce(operator.and_, episodes_details_clause)
|
||||
).group_by(
|
||||
TableHistory.video_path,
|
||||
TableHistory.language
|
||||
).objects()
|
||||
upgradable_episodes = database.execute("SELECT table_history.video_path, table_history.language, "
|
||||
"table_history.score, table_shows.hearing_impaired, "
|
||||
"table_episodes.scene_name, table_episodes.title,"
|
||||
"table_episodes.sonarrSeriesId, table_episodes.sonarrEpisodeId,"
|
||||
"MAX(table_history.timestamp), table_shows.languages, table_shows.forced "
|
||||
"FROM table_history INNER JOIN table_shows on "
|
||||
"table_shows.sonarrSeriesId = table_history.sonarrSeriesId INNER JOIN "
|
||||
"table_episodes on table_episodes.sonarrEpisodeId = "
|
||||
"table_history.sonarrEpisodeId WHERE action IN "
|
||||
"(" + ','.join(map(str, query_actions)) + ") AND timestamp > ? AND "
|
||||
"score is not null" + series_monitored_only_query_string +
|
||||
"GROUP BY table_history.video_path, table_history.language",
|
||||
(minimum_timestamp,))
|
||||
|
||||
upgradable_episodes_not_perfect = []
|
||||
for upgradable_episode in upgradable_episodes.dicts():
|
||||
for upgradable_episode in upgradable_episodes:
|
||||
if upgradable_episode['timestamp'] > minimum_timestamp:
|
||||
try:
|
||||
int(upgradable_episode['score'])
|
||||
|
@ -1131,28 +1028,20 @@ def upgrade_subtitles():
|
|||
count_episode_to_upgrade = len(episodes_to_upgrade)
|
||||
|
||||
if settings.general.getboolean('use_radarr'):
|
||||
upgradable_movies = TableHistoryMovie.select(
|
||||
TableHistoryMovie.video_path,
|
||||
TableHistoryMovie.language,
|
||||
TableHistoryMovie.score,
|
||||
TableMovies.hearing_impaired,
|
||||
TableMovies.scene_name,
|
||||
TableMovies.title,
|
||||
TableMovies.radarr_id,
|
||||
fn.MAX(TableHistoryMovie.timestamp).alias('timestamp'),
|
||||
TableMovies.languages,
|
||||
TableMovies.forced
|
||||
).join_from(
|
||||
TableHistoryMovie, TableMovies, JOIN.LEFT_OUTER
|
||||
).where(
|
||||
reduce(operator.and_, movies_details_clause)
|
||||
).group_by(
|
||||
TableHistoryMovie.video_path,
|
||||
TableHistoryMovie.language
|
||||
).objects()
|
||||
upgradable_movies = database.execute("SELECT table_history_movie.video_path, table_history_movie.language, "
|
||||
"table_history_movie.score, table_movies.hearing_impaired, "
|
||||
"table_movies.sceneName, table_movies.title, table_movies.radarrId, "
|
||||
"MAX(table_history_movie.timestamp), table_movies.languages, "
|
||||
"table_movies.forced FROM table_history_movie INNER JOIN "
|
||||
"table_movies on table_movies.radarrId = table_history_movie.radarrId "
|
||||
"WHERE action IN (" + ','.join(map(str, query_actions)) +
|
||||
") AND timestamp > ? AND score is not null" +
|
||||
movies_monitored_only_query_string +
|
||||
" GROUP BY table_history_movie.video_path, table_history_movie.language",
|
||||
(minimum_timestamp,))
|
||||
|
||||
upgradable_movies_not_perfect = []
|
||||
for upgradable_movie in upgradable_movies.dicts():
|
||||
for upgradable_movie in upgradable_movies:
|
||||
if upgradable_movie['timestamp'] > minimum_timestamp:
|
||||
try:
|
||||
int(upgradable_movie['score'])
|
||||
|
@ -1218,8 +1107,8 @@ def upgrade_subtitles():
|
|||
provider = result[3]
|
||||
score = result[4]
|
||||
store_subtitles(episode['video_path'], path_replace(episode['video_path']))
|
||||
history_log(3, episode['sonarr_series_id'], episode['sonarr_episode_id'], message, path, language_code, provider, score)
|
||||
send_notifications(episode['sonarr_series_id'], episode['sonarr_episode_id'], message)
|
||||
history_log(3, episode['sonarrSeriesId'], episode['sonarrEpisodeId'], message, path, language_code, provider, score)
|
||||
send_notifications(episode['sonarrSeriesId'], episode['sonarrEpisodeId'], message)
|
||||
|
||||
if settings.general.getboolean('use_radarr'):
|
||||
for i, movie in enumerate(movies_to_upgrade, 1):
|
||||
|
@ -1254,7 +1143,7 @@ def upgrade_subtitles():
|
|||
is_forced,
|
||||
providers_list,
|
||||
providers_auth,
|
||||
str(movie['scene_name']),
|
||||
str(movie['sceneName']),
|
||||
movie['title'],
|
||||
'movie',
|
||||
forced_minimum_score=int(movie['score']),
|
||||
|
@ -1267,5 +1156,5 @@ def upgrade_subtitles():
|
|||
provider = result[3]
|
||||
score = result[4]
|
||||
store_subtitles_movie(movie['video_path'], path_replace_movie(movie['video_path']))
|
||||
history_log_movie(3, movie['radarr_id'], message, path, language_code, provider, score)
|
||||
send_notifications_movie(movie['radarr_id'], message)
|
||||
history_log_movie(3, movie['radarrId'], message, path, language_code, provider, score)
|
||||
send_notifications_movie(movie['radarrId'], message)
|
||||
|
|
|
@ -9,9 +9,7 @@ import rarfile
|
|||
from cork import Cork
|
||||
from backports import configparser2
|
||||
from config import settings
|
||||
from check_update import check_releases
|
||||
from get_args import args
|
||||
from utils import get_binary
|
||||
|
||||
from dogpile.cache.region import register_backend as register_cache_backend
|
||||
import subliminal
|
||||
|
@ -55,6 +53,27 @@ if not os.path.exists(os.path.join(args.config_dir, 'cache')):
|
|||
os.mkdir(os.path.join(args.config_dir, 'cache'))
|
||||
logging.debug("BAZARR Created cache folder")
|
||||
|
||||
# create database file
|
||||
if not os.path.exists(os.path.join(args.config_dir, 'db', 'bazarr.db')):
|
||||
import sqlite3
|
||||
# Get SQL script from file
|
||||
fd = open(os.path.join(os.path.dirname(__file__), 'create_db.sql'), 'r')
|
||||
script = fd.read()
|
||||
# Close SQL script file
|
||||
fd.close()
|
||||
# Open database connection
|
||||
db = sqlite3.connect(os.path.join(args.config_dir, 'db', 'bazarr.db'), timeout=30)
|
||||
c = db.cursor()
|
||||
# Execute script and commit change to database
|
||||
c.executescript(script)
|
||||
# Close database connection
|
||||
db.close()
|
||||
logging.info('BAZARR Database created successfully')
|
||||
|
||||
# upgrade database schema
|
||||
from database import db_upgrade
|
||||
db_upgrade()
|
||||
|
||||
# Configure dogpile file caching for Subliminal request
|
||||
register_cache_backend("subzero.cache.file", "subzero.cache_backends.file", "SZFileBackend")
|
||||
subliminal.region.configure('subzero.cache.file', expiration_time=datetime.timedelta(days=30),
|
||||
|
@ -62,6 +81,7 @@ subliminal.region.configure('subzero.cache.file', expiration_time=datetime.timed
|
|||
subliminal.region.backend.sync()
|
||||
|
||||
if not os.path.exists(os.path.join(args.config_dir, 'config', 'releases.txt')):
|
||||
from check_update import check_releases
|
||||
check_releases()
|
||||
logging.debug("BAZARR Created releases file")
|
||||
|
||||
|
@ -88,6 +108,7 @@ if not os.path.exists(os.path.normpath(os.path.join(args.config_dir, 'config', '
|
|||
|
||||
|
||||
def init_binaries():
|
||||
from utils import get_binary
|
||||
exe = get_binary("unrar")
|
||||
|
||||
rarfile.UNRAR_TOOL = exe
|
||||
|
@ -100,7 +121,7 @@ def init_binaries():
|
|||
rarfile.OPEN_ARGS = rarfile.ORIG_OPEN_ARGS
|
||||
rarfile.EXTRACT_ARGS = rarfile.ORIG_EXTRACT_ARGS
|
||||
rarfile.TEST_ARGS = rarfile.ORIG_TEST_ARGS
|
||||
logging.info("Using UnRAR from: %s", exe)
|
||||
logging.debug("Using UnRAR from: %s", exe)
|
||||
unrar = exe
|
||||
|
||||
return unrar
|
||||
|
|
|
@ -14,11 +14,9 @@ from subliminal import core
|
|||
from subliminal_patch import search_external_subtitles
|
||||
from bs4 import UnicodeDammit
|
||||
from itertools import islice
|
||||
from database import TableShows, TableEpisodes, TableMovies
|
||||
from peewee import fn, JOIN
|
||||
from functools import reduce
|
||||
|
||||
from get_args import args
|
||||
from database import database
|
||||
from get_languages import alpha2_from_alpha3, get_language_set
|
||||
from config import settings
|
||||
from helper import path_replace, path_replace_movie, path_replace_reverse, \
|
||||
|
@ -108,15 +106,14 @@ def store_subtitles(original_path, reversed_path):
|
|||
actual_subtitles.append([str(detected_language), path_replace_reverse(
|
||||
os.path.join(os.path.dirname(reversed_path), subtitle))])
|
||||
|
||||
update_count = TableEpisodes.update(
|
||||
{
|
||||
TableEpisodes.subtitles: str(actual_subtitles)
|
||||
}
|
||||
).where(
|
||||
TableEpisodes.path == original_path
|
||||
).execute()
|
||||
if update_count > 0:
|
||||
database.execute("UPDATE table_episodes SET subtitles=? WHERE path=?",
|
||||
(str(actual_subtitles), original_path))
|
||||
episode = database.execute("SELECT sonarrEpisodeId FROM table_episodes WHERE path=?",
|
||||
(original_path,), only_one=True)
|
||||
|
||||
if len(episode):
|
||||
logging.debug("BAZARR storing those languages to DB: " + str(actual_subtitles))
|
||||
list_missing_subtitles(epno=episode['sonarrEpisodeId'])
|
||||
else:
|
||||
logging.debug("BAZARR haven't been able to update existing subtitles to DB : " + str(actual_subtitles))
|
||||
else:
|
||||
|
@ -124,14 +121,6 @@ def store_subtitles(original_path, reversed_path):
|
|||
|
||||
logging.debug('BAZARR ended subtitles indexing for this file: ' + reversed_path)
|
||||
|
||||
episode = TableEpisodes.select(
|
||||
TableEpisodes.sonarr_episode_id
|
||||
).where(
|
||||
TableEpisodes.path == path_replace_reverse(file)
|
||||
).first()
|
||||
|
||||
list_missing_subtitles(epno=episode.sonarr_episode_id)
|
||||
|
||||
return actual_subtitles
|
||||
|
||||
|
||||
|
@ -216,15 +205,13 @@ def store_subtitles_movie(original_path, reversed_path):
|
|||
actual_subtitles.append([str(detected_language), path_replace_reverse_movie(
|
||||
os.path.join(os.path.dirname(reversed_path), dest_folder, subtitle))])
|
||||
|
||||
update_count = TableMovies.update(
|
||||
{
|
||||
TableMovies.subtitles: str(actual_subtitles)
|
||||
}
|
||||
).where(
|
||||
TableMovies.path == original_path
|
||||
).execute()
|
||||
if update_count > 0:
|
||||
database.execute("UPDATE table_movies SET subtitles=? WHERE path=?",
|
||||
(str(actual_subtitles), original_path))
|
||||
movie = database.execute("SELECT radarrId FROM table_movies WHERE path=?", (original_path,))
|
||||
|
||||
if len(movie):
|
||||
logging.debug("BAZARR storing those languages to DB: " + str(actual_subtitles))
|
||||
list_missing_subtitles_movies(no=movie[0]['radarrId'])
|
||||
else:
|
||||
logging.debug("BAZARR haven't been able to update existing subtitles to DB : " + str(actual_subtitles))
|
||||
else:
|
||||
|
@ -232,34 +219,21 @@ def store_subtitles_movie(original_path, reversed_path):
|
|||
|
||||
logging.debug('BAZARR ended subtitles indexing for this file: ' + reversed_path)
|
||||
|
||||
movie = TableMovies.select(
|
||||
TableMovies.radarr_id
|
||||
).where(
|
||||
TableMovies.path == path_replace_reverse_movie(file)
|
||||
).first()
|
||||
|
||||
list_missing_subtitles_movies(no=movie.radarr_id)
|
||||
|
||||
return actual_subtitles
|
||||
|
||||
|
||||
def list_missing_subtitles(no=None, epno=None):
|
||||
episodes_subtitles_clause = (TableShows.sonarr_series_id.is_null(False))
|
||||
if no is not None:
|
||||
episodes_subtitles_clause = (TableShows.sonarr_series_id == no)
|
||||
episodes_subtitles_clause = " WHERE table_episodes.sonarrSeriesId=" + str(no)
|
||||
elif epno is not None:
|
||||
episodes_subtitles_clause = (TableEpisodes.sonarr_episode_id == epno)
|
||||
episodes_subtitles = TableEpisodes.select(
|
||||
TableShows.sonarr_series_id,
|
||||
TableEpisodes.sonarr_episode_id,
|
||||
TableEpisodes.subtitles,
|
||||
TableShows.languages,
|
||||
TableShows.forced
|
||||
).join_from(
|
||||
TableEpisodes, TableShows, JOIN.LEFT_OUTER
|
||||
).where(
|
||||
reduce(operator.and_, episodes_subtitles_clause)
|
||||
).objects()
|
||||
episodes_subtitles_clause = " WHERE table_episodes.sonarrEpisodeId=" + str(epno)
|
||||
else:
|
||||
episodes_subtitles_clause = ""
|
||||
episodes_subtitles = database.execute("SELECT table_shows.sonarrSeriesId, table_episodes.sonarrEpisodeId, "
|
||||
"table_episodes.subtitles, table_shows.languages, table_shows.forced "
|
||||
"FROM table_episodes LEFT JOIN table_shows "
|
||||
"on table_episodes.sonarrSeriesId = table_shows.sonarrSeriesId" +
|
||||
episodes_subtitles_clause)
|
||||
|
||||
missing_subtitles_global = []
|
||||
use_embedded_subs = settings.general.getboolean('use_embedded_subs')
|
||||
|
@ -269,27 +243,27 @@ def list_missing_subtitles(no=None, epno=None):
|
|||
actual_subtitles = []
|
||||
desired_subtitles = []
|
||||
missing_subtitles = []
|
||||
if episode_subtitles.subtitles is not None:
|
||||
if episode_subtitles['subtitles'] is not None:
|
||||
if use_embedded_subs:
|
||||
actual_subtitles = ast.literal_eval(episode_subtitles.subtitles)
|
||||
actual_subtitles = ast.literal_eval(episode_subtitles['subtitles'])
|
||||
else:
|
||||
actual_subtitles_temp = ast.literal_eval(episode_subtitles.subtitles)
|
||||
actual_subtitles_temp = ast.literal_eval(episode_subtitles['subtitles'])
|
||||
for subtitle in actual_subtitles_temp:
|
||||
if subtitle[1] is not None:
|
||||
actual_subtitles.append(subtitle)
|
||||
if episode_subtitles.languages is not None:
|
||||
desired_subtitles = ast.literal_eval(episode_subtitles.languages)
|
||||
if episode_subtitles.forced == "True" and desired_subtitles is not None:
|
||||
if episode_subtitles['languages'] is not None:
|
||||
desired_subtitles = ast.literal_eval(episode_subtitles['languages'])
|
||||
if episode_subtitles['forced'] == "True" and desired_subtitles is not None:
|
||||
for i, desired_subtitle in enumerate(desired_subtitles):
|
||||
desired_subtitles[i] = desired_subtitle + ":forced"
|
||||
elif episode_subtitles.forced == "Both" and desired_subtitles is not None:
|
||||
elif episode_subtitles['forced'] == "Both" and desired_subtitles is not None:
|
||||
for desired_subtitle in desired_subtitles:
|
||||
desired_subtitles_temp.append(desired_subtitle)
|
||||
desired_subtitles_temp.append(desired_subtitle + ":forced")
|
||||
desired_subtitles = desired_subtitles_temp
|
||||
actual_subtitles_list = []
|
||||
if desired_subtitles is None:
|
||||
missing_subtitles_global.append(tuple(['[]', episode_subtitles.sonarr_episode_id]))
|
||||
missing_subtitles_global.append(tuple(['[]', episode_subtitles['sonarrEpisodeId']]))
|
||||
else:
|
||||
for item in actual_subtitles:
|
||||
if item[0] == "pt-BR":
|
||||
|
@ -299,31 +273,21 @@ def list_missing_subtitles(no=None, epno=None):
|
|||
else:
|
||||
actual_subtitles_list.append(item[0])
|
||||
missing_subtitles = list(set(desired_subtitles) - set(actual_subtitles_list))
|
||||
missing_subtitles_global.append(tuple([str(missing_subtitles), episode_subtitles.sonarr_episode_id]))
|
||||
missing_subtitles_global.append(tuple([str(missing_subtitles), episode_subtitles['sonarrEpisodeId']]))
|
||||
|
||||
for missing_subtitles_item in missing_subtitles_global:
|
||||
TableEpisodes.update(
|
||||
{
|
||||
TableEpisodes.missing_subtitles: missing_subtitles_item[0]
|
||||
}
|
||||
).where(
|
||||
TableEpisodes.sonarr_episode_id == missing_subtitles_item[1]
|
||||
).execute()
|
||||
database.execute("UPDATE table_episodes SET missing_subtitles=? WHERE sonarrEpisodeId=?",
|
||||
(missing_subtitles_item[0], missing_subtitles_item[1]))
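A small worked example of the missing-subtitles computation performed above (the movie variant below follows the same logic): desired languages come from the show's settings, optionally suffixed with ":forced", and actual languages from the subtitles indexed on disk.

desired_subtitles = ['en', 'fr:forced']      # configured for the series
actual_subtitles_list = ['en']               # e.g. only an English subtitle was found on disk
missing_subtitles = list(set(desired_subtitles) - set(actual_subtitles_list))
# -> ['fr:forced']; stored as "['fr:forced']" in table_episodes.missing_subtitles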
|
||||
|
||||
|
||||
def list_missing_subtitles_movies(no=None):
|
||||
movies_subtitles_clause = (TableMovies.radarr_id.is_null(False))
|
||||
if no is not None:
|
||||
movies_subtitles_clause = (TableMovies.radarr_id == no)
|
||||
movies_subtitles_clause = " WHERE radarrId=" + str(no)
|
||||
else:
|
||||
movies_subtitles_clause = ""
|
||||
|
||||
movies_subtitles = TableMovies.select(
|
||||
TableMovies.radarr_id,
|
||||
TableMovies.subtitles,
|
||||
TableMovies.languages,
|
||||
TableMovies.forced
|
||||
).where(
|
||||
reduce(operator.and_, movies_subtitles_clause)
|
||||
)
|
||||
movies_subtitles = database.execute("SELECT radarrId, subtitles, languages, forced FROM table_movies" +
|
||||
movies_subtitles_clause)
|
||||
|
||||
missing_subtitles_global = []
|
||||
use_embedded_subs = settings.general.getboolean('use_embedded_subs')
|
||||
|
@ -333,27 +297,27 @@ def list_missing_subtitles_movies(no=None):
|
|||
actual_subtitles = []
|
||||
desired_subtitles = []
|
||||
missing_subtitles = []
|
||||
if movie_subtitles.subtitles is not None:
|
||||
if movie_subtitles['subtitles'] is not None:
|
||||
if use_embedded_subs:
|
||||
actual_subtitles = ast.literal_eval(movie_subtitles.subtitles)
|
||||
actual_subtitles = ast.literal_eval(movie_subtitles['subtitles'])
|
||||
else:
|
||||
actual_subtitles_temp = ast.literal_eval(movie_subtitles.subtitles)
|
||||
actual_subtitles_temp = ast.literal_eval(movie_subtitles['subtitles'])
|
||||
for subtitle in actual_subtitles_temp:
|
||||
if subtitle[1] is not None:
|
||||
actual_subtitles.append(subtitle)
|
||||
if movie_subtitles.languages is not None:
|
||||
desired_subtitles = ast.literal_eval(movie_subtitles.languages)
|
||||
if movie_subtitles.forced == "True" and desired_subtitles is not None:
|
||||
if movie_subtitles['languages'] is not None:
|
||||
desired_subtitles = ast.literal_eval(movie_subtitles['languages'])
|
||||
if movie_subtitles['forced'] == "True" and desired_subtitles is not None:
|
||||
for i, desired_subtitle in enumerate(desired_subtitles):
|
||||
desired_subtitles[i] = desired_subtitle + ":forced"
|
||||
elif movie_subtitles.forced == "Both" and desired_subtitles is not None:
|
||||
elif movie_subtitles['forced'] == "Both" and desired_subtitles is not None:
|
||||
for desired_subtitle in desired_subtitles:
|
||||
desired_subtitles_temp.append(desired_subtitle)
|
||||
desired_subtitles_temp.append(desired_subtitle + ":forced")
|
||||
desired_subtitles = desired_subtitles_temp
|
||||
actual_subtitles_list = []
|
||||
if desired_subtitles is None:
|
||||
missing_subtitles_global.append(tuple(['[]', movie_subtitles.radarr_id]))
|
||||
missing_subtitles_global.append(tuple(['[]', movie_subtitles['radarrId']]))
|
||||
else:
|
||||
for item in actual_subtitles:
|
||||
if item[0] == "pt-BR":
|
||||
|
@ -363,66 +327,49 @@ def list_missing_subtitles_movies(no=None):
|
|||
else:
|
||||
actual_subtitles_list.append(item[0])
|
||||
missing_subtitles = list(set(desired_subtitles) - set(actual_subtitles_list))
|
||||
missing_subtitles_global.append(tuple([str(missing_subtitles), movie_subtitles.radarr_id]))
|
||||
missing_subtitles_global.append(tuple([str(missing_subtitles), movie_subtitles['radarrId']]))
|
||||
|
||||
for missing_subtitles_item in missing_subtitles_global:
|
||||
TableMovies.update(
|
||||
{
|
||||
TableMovies.missing_subtitles: missing_subtitles_item[0]
|
||||
}
|
||||
).where(
|
||||
TableMovies.radarr_id == missing_subtitles_item[1]
|
||||
).execute()
|
||||
database.execute("UPDATE table_movies SET missing_subtitles=? WHERE radarrId=?",
|
||||
(missing_subtitles_item[0], missing_subtitles_item[1]))
|
||||
|
||||
|
||||
def series_full_scan_subtitles():
|
||||
episodes = TableEpisodes.select(
|
||||
TableEpisodes.path
|
||||
)
|
||||
count_episodes = episodes.count()
|
||||
episodes = database.execute("SELECT path FROM table_episodes")
|
||||
count_episodes = len(episodes)
|
||||
|
||||
for i, episode in enumerate(episodes, 1):
|
||||
notifications.write(msg='Updating all episodes subtitles from disk...',
|
||||
queue='list_subtitles_series', item=i, length=count_episodes)
|
||||
store_subtitles(episode.path, path_replace(episode.path))
|
||||
store_subtitles(episode['path'], path_replace(episode['path']))
|
||||
|
||||
gc.collect()
|
||||
|
||||
|
||||
def movies_full_scan_subtitles():
|
||||
movies = TableMovies.select(
|
||||
TableMovies.path
|
||||
)
|
||||
count_movies = movies.count()
|
||||
movies = database.execute("SELECT path FROM table_movies")
|
||||
count_movies = len(movies)
|
||||
|
||||
for i, movie in enumerate(movies, 1):
|
||||
notifications.write(msg='Updating all movies subtitles from disk...',
|
||||
queue='list_subtitles_movies', item=i, length=count_movies)
|
||||
store_subtitles_movie(movie.path, path_replace_movie(movie.path))
|
||||
store_subtitles_movie(movie['path'], path_replace_movie(movie['path']))
|
||||
|
||||
gc.collect()
|
||||
|
||||
|
||||
def series_scan_subtitles(no):
|
||||
episodes = TableEpisodes.select(
|
||||
TableEpisodes.path
|
||||
).where(
|
||||
TableEpisodes.sonarr_series_id == no
|
||||
)
|
||||
episodes = database.execute("SELECT path FROM table_episodes WHERE sonarrSeriesId=?", (no,))
|
||||
|
||||
for episode in episodes:
|
||||
store_subtitles(episode.path, path_replace(episode.path))
|
||||
store_subtitles(episode['path'], path_replace(episode['path']))
|
||||
|
||||
|
||||
def movies_scan_subtitles(no):
|
||||
movies = TableMovies.select(
|
||||
TableMovies.path
|
||||
).where(
|
||||
TableMovies.radarr_id == no
|
||||
)
|
||||
movies = database.execute("SELECT path FROM table_movies WHERE radarrId=?", (no,))
|
||||
|
||||
for movie in movies:
|
||||
store_subtitles_movie(movie.path, path_replace_movie(movie.path))
|
||||
store_subtitles_movie(movie['path'], path_replace_movie(movie['path']))
|
||||
|
||||
|
||||
def get_external_subtitles_path(file, subtitle):
|
||||
|
|
|
@ -77,7 +77,7 @@ def configure_logging(debug=False):
|
|||
logger.addHandler(fh)
|
||||
|
||||
if debug:
|
||||
logging.getLogger("peewee").setLevel(logging.INFO)
|
||||
logging.getLogger("sqlite3worker").setLevel(logging.DEBUG)
|
||||
logging.getLogger("apscheduler").setLevel(logging.DEBUG)
|
||||
logging.getLogger("subliminal").setLevel(logging.DEBUG)
|
||||
logging.getLogger("subliminal_patch").setLevel(logging.DEBUG)
|
||||
|
@ -89,6 +89,7 @@ def configure_logging(debug=False):
|
|||
logging.debug('Operating system: %s', platform.platform())
|
||||
logging.debug('Python version: %s', platform.python_version())
|
||||
else:
|
||||
logging.getLogger("sqlite3worker").setLevel(logging.CRITICAL)
|
||||
logging.getLogger("apscheduler").setLevel(logging.WARNING)
|
||||
logging.getLogger("subliminal").setLevel(logging.CRITICAL)
|
||||
logging.getLogger("subliminal_patch").setLevel(logging.CRITICAL)
|
||||
|
|
bazarr/main.py (816 changed lines): file diff suppressed because it is too large.
|
@ -6,7 +6,7 @@ import os
|
|||
import logging
|
||||
|
||||
from get_args import args
|
||||
from database import TableSettingsNotifier, TableShows, TableEpisodes, TableMovies
|
||||
from database import database
|
||||
|
||||
|
||||
def update_notifier():
|
||||
|
@ -19,13 +19,11 @@ def update_notifier():
|
|||
notifiers_new = []
|
||||
notifiers_old = []
|
||||
|
||||
notifiers_current_db = TableSettingsNotifier.select(
|
||||
TableSettingsNotifier.name
|
||||
)
|
||||
notifiers_current_db = database.execute("SELECT name FROM table_settings_notifier")
|
||||
|
||||
notifiers_current = []
|
||||
for notifier in notifiers_current_db:
|
||||
notifiers_current.append(notifier.name)
|
||||
notifiers_current.append(notifier['name'])
|
||||
|
||||
for x in results['schemas']:
|
||||
if x['service_name'] not in notifiers_current:
|
||||
|
@ -38,60 +36,34 @@ def update_notifier():
|
|||
notifiers_to_delete = list(set(notifier_current) - set(notifiers_old))
|
||||
|
||||
for notifier_new in notifiers_new:
|
||||
TableSettingsNotifier.insert(
|
||||
{
|
||||
TableSettingsNotifier.name: notifier_new,
|
||||
TableSettingsNotifier.enabled: 0
|
||||
}
|
||||
).execute()
|
||||
database.execute("INSERT INTO table_settings_notifier (name, enabled) VALUES (?, ?)", (notifier_new, 0))
|
||||
|
||||
for notifier_to_delete in notifiers_to_delete:
|
||||
TableSettingsNotifier.delete().where(
|
||||
TableSettingsNotifier.name == notifier_to_delete
|
||||
).execute()
|
||||
database.execute("DELETE FROM table_settings_notifier WHERE name=?", (notifier_to_delete,))
|
||||
|
||||
|
||||
def get_notifier_providers():
|
||||
providers = TableSettingsNotifier.select(
|
||||
TableSettingsNotifier.name,
|
||||
TableSettingsNotifier.url
|
||||
).where(
|
||||
TableSettingsNotifier.enabled == 1
|
||||
)
|
||||
|
||||
providers = database.execute("SELECT name, url FROM table_settings_notifier WHERE enabled=1")
|
||||
return providers
|
||||
|
||||
|
||||
def get_series_name(sonarrSeriesId):
|
||||
data = TableShows.select(
|
||||
TableShows.title
|
||||
).where(
|
||||
TableShows.sonarr_series_id == sonarrSeriesId
|
||||
).first()
|
||||
data = database.execute("SELECT title FROM table_shows WHERE sonarrSeriesId=?", (sonarrSeriesId,), only_one=True)
|
||||
|
||||
return data.title
|
||||
return data['title'] or None
|
||||
|
||||
|
||||
def get_episode_name(sonarrEpisodeId):
|
||||
data = TableEpisodes.select(
|
||||
TableEpisodes.title,
|
||||
TableEpisodes.season,
|
||||
TableEpisodes.episode
|
||||
).where(
|
||||
TableEpisodes.sonarr_episode_id == sonarrEpisodeId
|
||||
).first()
|
||||
data = database.execute("SELECT title, season, episode FROM table_episodes WHERE sonarrEpisodeId=?",
|
||||
(sonarrEpisodeId,), only_one=True)
|
||||
|
||||
return data.title, data.season, data.episode
|
||||
return data['title'], data['season'], data['episode']
|
||||
|
||||
|
||||
def get_movies_name(radarrId):
|
||||
data = TableMovies.select(
|
||||
TableMovies.title
|
||||
).where(
|
||||
TableMovies.radarr_id == radarrId
|
||||
).first()
|
||||
|
||||
return data.title
|
||||
data = database.execute("SELECT title FROM table_movies WHERE radarrId=?", (radarrId,), only_one=True)
|
||||
|
||||
return data['title']
|
||||
|
||||
|
||||
def send_notifications(sonarrSeriesId, sonarrEpisodeId, message):
|
||||
|
@ -102,8 +74,8 @@ def send_notifications(sonarrSeriesId, sonarrEpisodeId, message):
|
|||
apobj = apprise.Apprise()
|
||||
|
||||
for provider in providers:
|
||||
if provider.url is not None:
|
||||
apobj.add(provider.url)
|
||||
if provider['url'] is not None:
|
||||
apobj.add(provider['url'])
|
||||
|
||||
apobj.notify(
|
||||
title='Bazarr notification',
|
||||
|
@ -118,8 +90,8 @@ def send_notifications_movie(radarrId, message):
|
|||
apobj = apprise.Apprise()
|
||||
|
||||
for provider in providers:
|
||||
if provider.url is not None:
|
||||
apobj.add(provider.url)
|
||||
if provider['url'] is not None:
|
||||
apobj.add(provider['url'])
|
||||
|
||||
apobj.notify(
|
||||
title='Bazarr notification',
|
||||
|
|
|
@ -11,7 +11,7 @@ import requests
|
|||
from whichcraft import which
|
||||
from get_args import args
|
||||
from config import settings, url_sonarr, url_radarr
|
||||
from database import TableHistory, TableHistoryMovie
|
||||
from database import database
|
||||
|
||||
from subliminal import region as subliminal_cache_region
|
||||
import datetime
|
||||
|
@ -20,35 +20,18 @@ import glob
|
|||
|
||||
def history_log(action, sonarrSeriesId, sonarrEpisodeId, description, video_path=None, language=None, provider=None,
|
||||
score=None, forced=False):
|
||||
TableHistory.insert(
|
||||
{
|
||||
TableHistory.action: action,
|
||||
TableHistory.sonarr_series_id: sonarrSeriesId,
|
||||
TableHistory.sonarr_episode_id: sonarrEpisodeId,
|
||||
TableHistory.timestamp: time.time(),
|
||||
TableHistory.description: description,
|
||||
TableHistory.video_path: video_path,
|
||||
TableHistory.language: language,
|
||||
TableHistory.provider: provider,
|
||||
TableHistory.score: score
|
||||
}
|
||||
).execute()
|
||||
database.execute("INSERT INTO table_history (action, sonarrSeriesId, sonarrEpisodeId, timestamp, description,"
|
||||
"video_path, language, provider, score) VALUES (?,?,?,?,?,?,?,?,?)", (action, sonarrSeriesId,
|
||||
sonarrEpisodeId, time.time(),
|
||||
description, video_path,
|
||||
language, provider, score))
|
||||
|
||||
|
||||
def history_log_movie(action, radarrId, description, video_path=None, language=None, provider=None, score=None,
|
||||
forced=False):
|
||||
TableHistoryMovie.insert(
|
||||
{
|
||||
TableHistoryMovie.action: action,
|
||||
TableHistoryMovie.radarr_id: radarrId,
|
||||
TableHistoryMovie.timestamp: time.time(),
|
||||
TableHistoryMovie.description: description,
|
||||
TableHistoryMovie.video_path: video_path,
|
||||
TableHistoryMovie.language: language,
|
||||
TableHistoryMovie.provider: provider,
|
||||
TableHistoryMovie.score: score
|
||||
}
|
||||
).execute()
|
||||
database.execute("INSERT INTO table_history_movie (action, radarrId, timestamp, description, video_path, language, "
|
||||
"provider, score) VALUES (?,?,?,?,?,?,?,?)", (action, radarrId, time.time(), description,
|
||||
video_path, language, provider, score))
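A hedged usage sketch for the two helpers above, matching how they are called elsewhere in this diff (action 1 after a download, action 3 after an upgrade); all identifiers, paths and scores are illustrative.

history_log(1, 123, 4567, "Downloaded en subtitles from opensubtitles",
            video_path='/tv/Some Show/Season 01/episode.mkv',
            language='en', provider='opensubtitles', score=345)

history_log_movie(1, 89, "Downloaded en subtitles from opensubtitles",
                  video_path='/movies/Some Movie (2019)/movie.mkv',
                  language='en', provider='opensubtitles', score=120)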
|
||||
|
||||
|
||||
def get_binary(name):
|
||||
|
@ -106,7 +89,7 @@ def get_sonarr_version():
|
|||
try:
|
||||
sonarr_version = requests.get(sv, timeout=60, verify=False).json()['version']
|
||||
except Exception as e:
|
||||
logging.DEBUG('BAZARR cannot get Sonarr version')
|
||||
logging.debug('BAZARR cannot get Sonarr version')
|
||||
|
||||
return sonarr_version
|
||||
|
||||
|
@ -137,7 +120,7 @@ def get_radarr_version():
|
|||
try:
|
||||
radarr_version = requests.get(rv, timeout=60, verify=False).json()['version']
|
||||
except Exception as e:
|
||||
logging.DEBUG('BAZARR cannot get Radarr version')
|
||||
logging.debug('BAZARR cannot get Radarr version')
|
||||
|
||||
return radarr_version
|
||||
|
||||
|
|
libs/peewee.py (7508 changed lines): file diff suppressed because it is too large.
|
@ -1,136 +0,0 @@
|
|||
"""
|
||||
Peewee integration with APSW, "another python sqlite wrapper".
|
||||
|
||||
Project page: https://rogerbinns.github.io/apsw/
|
||||
|
||||
APSW is a really neat library that provides a thin wrapper on top of SQLite's
|
||||
C interface.
|
||||
|
||||
Here are just a few reasons to use APSW, taken from the documentation:
|
||||
|
||||
* APSW gives all functionality of SQLite, including virtual tables, virtual
|
||||
file system, blob i/o, backups and file control.
|
||||
* Connections can be shared across threads without any additional locking.
|
||||
* Transactions are managed explicitly by your code.
|
||||
* APSW can handle nested transactions.
|
||||
* Unicode is handled correctly.
|
||||
* APSW is faster.
|
||||
"""
|
||||
import apsw
|
||||
from peewee import *
|
||||
from peewee import __exception_wrapper__
|
||||
from peewee import BooleanField as _BooleanField
|
||||
from peewee import DateField as _DateField
|
||||
from peewee import DateTimeField as _DateTimeField
|
||||
from peewee import DecimalField as _DecimalField
|
||||
from peewee import TimeField as _TimeField
|
||||
from peewee import logger
|
||||
|
||||
from playhouse.sqlite_ext import SqliteExtDatabase
|
||||
|
||||
|
||||
class APSWDatabase(SqliteExtDatabase):
|
||||
server_version = tuple(int(i) for i in apsw.sqlitelibversion().split('.'))
|
||||
|
||||
def __init__(self, database, **kwargs):
|
||||
self._modules = {}
|
||||
super(APSWDatabase, self).__init__(database, **kwargs)
|
||||
|
||||
def register_module(self, mod_name, mod_inst):
|
||||
self._modules[mod_name] = mod_inst
|
||||
if not self.is_closed():
|
||||
self.connection().createmodule(mod_name, mod_inst)
|
||||
|
||||
def unregister_module(self, mod_name):
|
||||
del(self._modules[mod_name])
|
||||
|
||||
def _connect(self):
|
||||
conn = apsw.Connection(self.database, **self.connect_params)
|
||||
if self._timeout is not None:
|
||||
conn.setbusytimeout(self._timeout * 1000)
|
||||
try:
|
||||
self._add_conn_hooks(conn)
|
||||
except:
|
||||
conn.close()
|
||||
raise
|
||||
return conn
|
||||
|
||||
def _add_conn_hooks(self, conn):
|
||||
super(APSWDatabase, self)._add_conn_hooks(conn)
|
||||
self._load_modules(conn) # APSW-only.
|
||||
|
||||
def _load_modules(self, conn):
|
||||
for mod_name, mod_inst in self._modules.items():
|
||||
conn.createmodule(mod_name, mod_inst)
|
||||
return conn
|
||||
|
||||
def _load_aggregates(self, conn):
|
||||
for name, (klass, num_params) in self._aggregates.items():
|
||||
def make_aggregate():
|
||||
return (klass(), klass.step, klass.finalize)
|
||||
conn.createaggregatefunction(name, make_aggregate)
|
||||
|
||||
def _load_collations(self, conn):
|
||||
for name, fn in self._collations.items():
|
||||
conn.createcollation(name, fn)
|
||||
|
||||
def _load_functions(self, conn):
|
||||
for name, (fn, num_params) in self._functions.items():
|
||||
conn.createscalarfunction(name, fn, num_params)
|
||||
|
||||
def _load_extensions(self, conn):
|
||||
conn.enableloadextension(True)
|
||||
for extension in self._extensions:
|
||||
conn.loadextension(extension)
|
||||
|
||||
def load_extension(self, extension):
|
||||
self._extensions.add(extension)
|
||||
if not self.is_closed():
|
||||
conn = self.connection()
|
||||
conn.enableloadextension(True)
|
||||
conn.loadextension(extension)
|
||||
|
||||
def last_insert_id(self, cursor, query_type=None):
|
||||
return cursor.getconnection().last_insert_rowid()
|
||||
|
||||
def rows_affected(self, cursor):
|
||||
return cursor.getconnection().changes()
|
||||
|
||||
def begin(self, lock_type='deferred'):
|
||||
self.cursor().execute('begin %s;' % lock_type)
|
||||
|
||||
def commit(self):
|
||||
self.cursor().execute('commit;')
|
||||
|
||||
def rollback(self):
|
||||
self.cursor().execute('rollback;')
|
||||
|
||||
def execute_sql(self, sql, params=None, commit=True):
|
||||
logger.debug((sql, params))
|
||||
with __exception_wrapper__:
|
||||
cursor = self.cursor()
|
||||
cursor.execute(sql, params or ())
|
||||
return cursor
|
||||
|
||||
|
||||
def nh(s, v):
|
||||
if v is not None:
|
||||
return str(v)
|
||||
|
||||
class BooleanField(_BooleanField):
|
||||
def db_value(self, v):
|
||||
v = super(BooleanField, self).db_value(v)
|
||||
if v is not None:
|
||||
return v and 1 or 0
|
||||
|
||||
class DateField(_DateField):
|
||||
db_value = nh
|
||||
|
||||
class TimeField(_TimeField):
|
||||
db_value = nh
|
||||
|
||||
class DateTimeField(_DateTimeField):
|
||||
db_value = nh
|
||||
|
||||
class DecimalField(_DecimalField):
|
||||
db_value = nh
|
|
@ -1,452 +0,0 @@
|
|||
import csv
|
||||
import datetime
|
||||
from decimal import Decimal
|
||||
import json
|
||||
import operator
|
||||
try:
|
||||
from urlparse import urlparse
|
||||
except ImportError:
|
||||
from urllib.parse import urlparse
|
||||
import sys
|
||||
|
||||
from peewee import *
|
||||
from playhouse.db_url import connect
|
||||
from playhouse.migrate import migrate
|
||||
from playhouse.migrate import SchemaMigrator
|
||||
from playhouse.reflection import Introspector
|
||||
|
||||
if sys.version_info[0] == 3:
|
||||
basestring = str
|
||||
from functools import reduce
|
||||
def open_file(f, mode):
|
||||
return open(f, mode, encoding='utf8')
|
||||
else:
|
||||
open_file = open
|
||||
|
||||
|
||||
class DataSet(object):
|
||||
def __init__(self, url, bare_fields=False):
|
||||
if isinstance(url, Database):
|
||||
self._url = None
|
||||
self._database = url
|
||||
self._database_path = self._database.database
|
||||
else:
|
||||
self._url = url
|
||||
parse_result = urlparse(url)
|
||||
self._database_path = parse_result.path[1:]
|
||||
|
||||
# Connect to the database.
|
||||
self._database = connect(url)
|
||||
|
||||
self._database.connect()
|
||||
|
||||
# Introspect the database and generate models.
|
||||
self._introspector = Introspector.from_database(self._database)
|
||||
self._models = self._introspector.generate_models(
|
||||
skip_invalid=True,
|
||||
literal_column_names=True,
|
||||
bare_fields=bare_fields)
|
||||
self._migrator = SchemaMigrator.from_database(self._database)
|
||||
|
||||
class BaseModel(Model):
|
||||
class Meta:
|
||||
database = self._database
|
||||
self._base_model = BaseModel
|
||||
self._export_formats = self.get_export_formats()
|
||||
self._import_formats = self.get_import_formats()
|
||||
|
||||
def __repr__(self):
|
||||
return '<DataSet: %s>' % self._database_path
|
||||
|
||||
def get_export_formats(self):
|
||||
return {
|
||||
'csv': CSVExporter,
|
||||
'json': JSONExporter,
|
||||
'tsv': TSVExporter}
|
||||
|
||||
def get_import_formats(self):
|
||||
return {
|
||||
'csv': CSVImporter,
|
||||
'json': JSONImporter,
|
||||
'tsv': TSVImporter}
|
||||
|
||||
def __getitem__(self, table):
|
||||
if table not in self._models and table in self.tables:
|
||||
self.update_cache(table)
|
||||
return Table(self, table, self._models.get(table))
|
||||
|
||||
@property
|
||||
def tables(self):
|
||||
return self._database.get_tables()
|
||||
|
||||
def __contains__(self, table):
|
||||
return table in self.tables
|
||||
|
||||
def connect(self):
|
||||
self._database.connect()
|
||||
|
||||
def close(self):
|
||||
self._database.close()
|
||||
|
||||
def update_cache(self, table=None):
|
||||
if table:
|
||||
dependencies = [table]
|
||||
if table in self._models:
|
||||
model_class = self._models[table]
|
||||
dependencies.extend([
|
||||
related._meta.table_name for _, related, _ in
|
||||
model_class._meta.model_graph()])
|
||||
else:
|
||||
dependencies.extend(self.get_table_dependencies(table))
|
||||
else:
|
||||
dependencies = None # Update all tables.
|
||||
self._models = {}
|
||||
updated = self._introspector.generate_models(
|
||||
skip_invalid=True,
|
||||
table_names=dependencies,
|
||||
literal_column_names=True)
|
||||
self._models.update(updated)
|
||||
|
||||
def get_table_dependencies(self, table):
|
||||
stack = [table]
|
||||
accum = []
|
||||
seen = set()
|
||||
while stack:
|
||||
table = stack.pop()
|
||||
for fk_meta in self._database.get_foreign_keys(table):
|
||||
dest = fk_meta.dest_table
|
||||
if dest not in seen:
|
||||
stack.append(dest)
|
||||
accum.append(dest)
|
||||
return accum
|
||||
|
||||
def __enter__(self):
|
||||
self.connect()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if not self._database.is_closed():
|
||||
self.close()
|
||||
|
||||
def query(self, sql, params=None, commit=True):
|
||||
return self._database.execute_sql(sql, params, commit)
|
||||
|
||||
def transaction(self):
|
||||
if self._database.transaction_depth() == 0:
|
||||
return self._database.transaction()
|
||||
else:
|
||||
return self._database.savepoint()
|
||||
|
||||
def _check_arguments(self, filename, file_obj, format, format_dict):
|
||||
if filename and file_obj:
|
||||
raise ValueError('file is over-specified. Please use either '
|
||||
'filename or file_obj, but not both.')
|
||||
if not filename and not file_obj:
|
||||
raise ValueError('A filename or file-like object must be '
|
||||
'specified.')
|
||||
if format not in format_dict:
|
||||
valid_formats = ', '.join(sorted(format_dict.keys()))
|
||||
raise ValueError('Unsupported format "%s". Use one of %s.' % (
|
||||
format, valid_formats))
|
||||
|
||||
def freeze(self, query, format='csv', filename=None, file_obj=None,
|
||||
**kwargs):
|
||||
self._check_arguments(filename, file_obj, format, self._export_formats)
|
||||
if filename:
|
||||
file_obj = open_file(filename, 'w')
|
||||
|
||||
exporter = self._export_formats[format](query)
|
||||
exporter.export(file_obj, **kwargs)
|
||||
|
||||
if filename:
|
||||
file_obj.close()
|
||||
|
||||
def thaw(self, table, format='csv', filename=None, file_obj=None,
|
||||
strict=False, **kwargs):
|
||||
self._check_arguments(filename, file_obj, format, self._export_formats)
|
||||
if filename:
|
||||
file_obj = open_file(filename, 'r')
|
||||
|
||||
importer = self._import_formats[format](self[table], strict)
|
||||
count = importer.load(file_obj, **kwargs)
|
||||
|
||||
if filename:
|
||||
file_obj.close()
|
||||
|
||||
return count
|
||||
|
||||
|
||||
class Table(object):
|
||||
def __init__(self, dataset, name, model_class):
|
||||
self.dataset = dataset
|
||||
self.name = name
|
||||
if model_class is None:
|
||||
model_class = self._create_model()
|
||||
model_class.create_table()
|
||||
self.dataset._models[name] = model_class
|
||||
|
||||
@property
|
||||
def model_class(self):
|
||||
return self.dataset._models[self.name]
|
||||
|
||||
def __repr__(self):
|
||||
return '<Table: %s>' % self.name
|
||||
|
||||
def __len__(self):
|
||||
return self.find().count()
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.find().iterator())
|
||||
|
||||
def _create_model(self):
|
||||
class Meta:
|
||||
table_name = self.name
|
||||
return type(
|
||||
str(self.name),
|
||||
(self.dataset._base_model,),
|
||||
{'Meta': Meta})
|
||||
|
||||
def create_index(self, columns, unique=False):
|
||||
self.dataset._database.create_index(
|
||||
self.model_class,
|
||||
columns,
|
||||
unique=unique)
|
||||
|
||||
def _guess_field_type(self, value):
|
||||
if isinstance(value, basestring):
|
||||
return TextField
|
||||
if isinstance(value, (datetime.date, datetime.datetime)):
|
||||
return DateTimeField
|
||||
elif value is True or value is False:
|
||||
return BooleanField
|
||||
elif isinstance(value, int):
|
||||
return IntegerField
|
||||
elif isinstance(value, float):
|
||||
return FloatField
|
||||
elif isinstance(value, Decimal):
|
||||
return DecimalField
|
||||
return TextField
|
||||
|
||||
@property
|
||||
def columns(self):
|
||||
return [f.name for f in self.model_class._meta.sorted_fields]
|
||||
|
||||
def _migrate_new_columns(self, data):
|
||||
new_keys = set(data) - set(self.model_class._meta.fields)
|
||||
if new_keys:
|
||||
operations = []
|
||||
for key in new_keys:
|
||||
field_class = self._guess_field_type(data[key])
|
||||
field = field_class(null=True)
|
||||
operations.append(
|
||||
self.dataset._migrator.add_column(self.name, key, field))
|
||||
field.bind(self.model_class, key)
|
||||
|
||||
migrate(*operations)
|
||||
|
||||
self.dataset.update_cache(self.name)
|
||||
|
||||
def __getitem__(self, item):
|
||||
try:
|
||||
return self.model_class[item]
|
||||
except self.model_class.DoesNotExist:
|
||||
pass
|
||||
|
||||
def __setitem__(self, item, value):
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError('Table.__setitem__() value must be a dict')
|
||||
|
||||
pk = self.model_class._meta.primary_key
|
||||
value[pk.name] = item
|
||||
|
||||
try:
|
||||
with self.dataset.transaction() as txn:
|
||||
self.insert(**value)
|
||||
except IntegrityError:
|
||||
self.dataset.update_cache(self.name)
|
||||
self.update(columns=[pk.name], **value)
|
||||
|
||||
def __delitem__(self, item):
|
||||
del self.model_class[item]
|
||||
|
||||
def insert(self, **data):
|
||||
self._migrate_new_columns(data)
|
||||
return self.model_class.insert(**data).execute()
|
||||
|
||||
def _apply_where(self, query, filters, conjunction=None):
|
||||
conjunction = conjunction or operator.and_
|
||||
if filters:
|
||||
expressions = [
|
||||
(self.model_class._meta.fields[column] == value)
|
||||
for column, value in filters.items()]
|
||||
query = query.where(reduce(conjunction, expressions))
|
||||
return query
|
||||
|
||||
def update(self, columns=None, conjunction=None, **data):
|
||||
self._migrate_new_columns(data)
|
||||
filters = {}
|
||||
if columns:
|
||||
for column in columns:
|
||||
filters[column] = data.pop(column)
|
||||
|
||||
return self._apply_where(
|
||||
self.model_class.update(**data),
|
||||
filters,
|
||||
conjunction).execute()
|
||||
|
||||
def _query(self, **query):
|
||||
return self._apply_where(self.model_class.select(), query)
|
||||
|
||||
def find(self, **query):
|
||||
return self._query(**query).dicts()
|
||||
|
||||
def find_one(self, **query):
|
||||
try:
|
||||
return self.find(**query).get()
|
||||
except self.model_class.DoesNotExist:
|
||||
return None
|
||||
|
||||
def all(self):
|
||||
return self.find()
|
||||
|
||||
def delete(self, **query):
|
||||
return self._apply_where(self.model_class.delete(), query).execute()
|
||||
|
||||
def freeze(self, *args, **kwargs):
|
||||
return self.dataset.freeze(self.all(), *args, **kwargs)
|
||||
|
||||
def thaw(self, *args, **kwargs):
|
||||
return self.dataset.thaw(self.name, *args, **kwargs)
|
||||
|
||||
|
||||
class Exporter(object):
|
||||
def __init__(self, query):
|
||||
self.query = query
|
||||
|
||||
def export(self, file_obj):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class JSONExporter(Exporter):
|
||||
def __init__(self, query, iso8601_datetimes=False):
|
||||
super(JSONExporter, self).__init__(query)
|
||||
self.iso8601_datetimes = iso8601_datetimes
|
||||
|
||||
def _make_default(self):
|
||||
datetime_types = (datetime.datetime, datetime.date, datetime.time)
|
||||
|
||||
if self.iso8601_datetimes:
|
||||
def default(o):
|
||||
if isinstance(o, datetime_types):
|
||||
return o.isoformat()
|
||||
elif isinstance(o, Decimal):
|
||||
return str(o)
|
||||
raise TypeError('Unable to serialize %r as JSON' % o)
|
||||
else:
|
||||
def default(o):
|
||||
if isinstance(o, datetime_types + (Decimal,)):
|
||||
return str(o)
|
||||
raise TypeError('Unable to serialize %r as JSON' % o)
|
||||
return default
|
||||
|
||||
def export(self, file_obj, **kwargs):
|
||||
json.dump(
|
||||
list(self.query),
|
||||
file_obj,
|
||||
default=self._make_default(),
|
||||
**kwargs)
|
||||
|
||||
|
||||
class CSVExporter(Exporter):
|
||||
def export(self, file_obj, header=True, **kwargs):
|
||||
writer = csv.writer(file_obj, **kwargs)
|
||||
tuples = self.query.tuples().execute()
|
||||
tuples.initialize()
|
||||
if header and getattr(tuples, 'columns', None):
|
||||
writer.writerow([column for column in tuples.columns])
|
||||
for row in tuples:
|
||||
writer.writerow(row)
|
||||
|
||||
|
||||
class TSVExporter(CSVExporter):
|
||||
def export(self, file_obj, header=True, **kwargs):
|
||||
kwargs.setdefault('delimiter', '\t')
|
||||
return super(TSVExporter, self).export(file_obj, header, **kwargs)
|
||||
|
||||
|
||||
class Importer(object):
|
||||
def __init__(self, table, strict=False):
|
||||
self.table = table
|
||||
self.strict = strict
|
||||
|
||||
model = self.table.model_class
|
||||
self.columns = model._meta.columns
|
||||
self.columns.update(model._meta.fields)
|
||||
|
||||
def load(self, file_obj):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class JSONImporter(Importer):
|
||||
def load(self, file_obj, **kwargs):
|
||||
data = json.load(file_obj, **kwargs)
|
||||
count = 0
|
||||
|
||||
for row in data:
|
||||
if self.strict:
|
||||
obj = {}
|
||||
for key in row:
|
||||
field = self.columns.get(key)
|
||||
if field is not None:
|
||||
obj[field.name] = field.python_value(row[key])
|
||||
else:
|
||||
obj = row
|
||||
|
||||
if obj:
|
||||
self.table.insert(**obj)
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
|
||||
class CSVImporter(Importer):
|
||||
def load(self, file_obj, header=True, **kwargs):
|
||||
count = 0
|
||||
reader = csv.reader(file_obj, **kwargs)
|
||||
if header:
|
||||
try:
|
||||
header_keys = next(reader)
|
||||
except StopIteration:
|
||||
return count
|
||||
|
||||
if self.strict:
|
||||
header_fields = []
|
||||
for idx, key in enumerate(header_keys):
|
||||
if key in self.columns:
|
||||
header_fields.append((idx, self.columns[key]))
|
||||
else:
|
||||
header_fields = list(enumerate(header_keys))
|
||||
else:
|
||||
header_fields = list(enumerate(self.model._meta.sorted_fields))
|
||||
|
||||
if not header_fields:
|
||||
return count
|
||||
|
||||
for row in reader:
|
||||
obj = {}
|
||||
for idx, field in header_fields:
|
||||
if self.strict:
|
||||
obj[field.name] = field.python_value(row[idx])
|
||||
else:
|
||||
obj[field] = row[idx]
|
||||
|
||||
self.table.insert(**obj)
|
||||
count += 1
|
||||
|
||||
return count
|
||||
|
||||
|
||||
class TSVImporter(CSVImporter):
|
||||
def load(self, file_obj, header=True, **kwargs):
|
||||
kwargs.setdefault('delimiter', '\t')
|
||||
return super(TSVImporter, self).load(file_obj, header, **kwargs)
|
|
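The Table, Exporter and Importer classes above form a dict-like wrapper around a database table. A minimal usage sketch, assuming the playhouse.dataset module these classes ship in and an in-memory SQLite database (table and column names are illustrative):

from playhouse.dataset import DataSet

db = DataSet('sqlite:///:memory:')   # URL handled by playhouse.db_url
users = db['users']                  # table and model are created lazily

# insert() adds any unknown columns first (see _migrate_new_columns above).
users.insert(name='charlie', favorite_color='green')
users.insert(name='huey', favorite_color='white')

print(users.find_one(name='huey'))   # row returned as a plain dict
print(len(list(users.all())))        # 2

# freeze() routes through the JSONExporter defined above.
with open('users.json', 'w') as fh:
    db.freeze(users.all(), format='json', file_obj=fh)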
@@ -1,124 +0,0 @@
try:
    from urlparse import parse_qsl, unquote, urlparse
except ImportError:
    from urllib.parse import parse_qsl, unquote, urlparse

from peewee import *
from playhouse.pool import PooledMySQLDatabase
from playhouse.pool import PooledPostgresqlDatabase
from playhouse.pool import PooledSqliteDatabase
from playhouse.pool import PooledSqliteExtDatabase
from playhouse.sqlite_ext import SqliteExtDatabase


schemes = {
    'mysql': MySQLDatabase,
    'mysql+pool': PooledMySQLDatabase,
    'postgres': PostgresqlDatabase,
    'postgresql': PostgresqlDatabase,
    'postgres+pool': PooledPostgresqlDatabase,
    'postgresql+pool': PooledPostgresqlDatabase,
    'sqlite': SqliteDatabase,
    'sqliteext': SqliteExtDatabase,
    'sqlite+pool': PooledSqliteDatabase,
    'sqliteext+pool': PooledSqliteExtDatabase,
}


def register_database(db_class, *names):
    global schemes
    for name in names:
        schemes[name] = db_class


def parseresult_to_dict(parsed, unquote_password=False):

    # urlparse in python 2.6 is broken so query will be empty and instead
    # appended to path complete with '?'
    path_parts = parsed.path[1:].split('?')
    try:
        query = path_parts[1]
    except IndexError:
        query = parsed.query

    connect_kwargs = {'database': path_parts[0]}
    if parsed.username:
        connect_kwargs['user'] = parsed.username
    if parsed.password:
        connect_kwargs['password'] = parsed.password
        if unquote_password:
            connect_kwargs['password'] = unquote(connect_kwargs['password'])
    if parsed.hostname:
        connect_kwargs['host'] = parsed.hostname
    if parsed.port:
        connect_kwargs['port'] = parsed.port

    # Adjust parameters for MySQL.
    if parsed.scheme == 'mysql' and 'password' in connect_kwargs:
        connect_kwargs['passwd'] = connect_kwargs.pop('password')
    elif 'sqlite' in parsed.scheme and not connect_kwargs['database']:
        connect_kwargs['database'] = ':memory:'

    # Get additional connection args from the query string
    qs_args = parse_qsl(query, keep_blank_values=True)
    for key, value in qs_args:
        if value.lower() == 'false':
            value = False
        elif value.lower() == 'true':
            value = True
        elif value.isdigit():
            value = int(value)
        elif '.' in value and all(p.isdigit() for p in value.split('.', 1)):
            try:
                value = float(value)
            except ValueError:
                pass
        elif value.lower() in ('null', 'none'):
            value = None

        connect_kwargs[key] = value

    return connect_kwargs


def parse(url, unquote_password=False):
    parsed = urlparse(url)
    return parseresult_to_dict(parsed, unquote_password)


def connect(url, unquote_password=False, **connect_params):
    parsed = urlparse(url)
    connect_kwargs = parseresult_to_dict(parsed, unquote_password)
    connect_kwargs.update(connect_params)
    database_class = schemes.get(parsed.scheme)

    if database_class is None:
        if database_class in schemes:
            raise RuntimeError('Attempted to use "%s" but a required library '
                               'could not be imported.' % parsed.scheme)
        else:
            raise RuntimeError('Unrecognized or unsupported scheme: "%s".' %
                               parsed.scheme)

    return database_class(**connect_kwargs)


# Conditionally register additional databases.
try:
    from playhouse.pool import PooledPostgresqlExtDatabase
except ImportError:
    pass
else:
    register_database(
        PooledPostgresqlExtDatabase,
        'postgresext+pool',
        'postgresqlext+pool')

try:
    from playhouse.apsw_ext import APSWDatabase
except ImportError:
    pass
else:
    register_database(APSWDatabase, 'apsw')

try:
    from playhouse.postgres_ext import PostgresqlExtDatabase
except ImportError:
    pass
else:
    register_database(PostgresqlExtDatabase, 'postgresext', 'postgresqlext')
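A short sketch of how the scheme table and helpers above are typically used; the URLs are illustrative only:

from playhouse.db_url import connect, parse

# The scheme selects the Database class; query-string arguments become
# extra connection keyword arguments.
db = connect('sqlite:///:memory:')

print(parse('postgresql://postgres:secret@localhost:5432/app?sslmode=require'))
# {'database': 'app', 'user': 'postgres', 'password': 'secret',
#  'host': 'localhost', 'port': 5432, 'sslmode': 'require'}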
@@ -1,64 +0,0 @@
try:
    import bz2
except ImportError:
    bz2 = None
try:
    import zlib
except ImportError:
    zlib = None
try:
    import cPickle as pickle
except ImportError:
    import pickle
import sys

from peewee import BlobField
from peewee import buffer_type


PY2 = sys.version_info[0] == 2


class CompressedField(BlobField):
    ZLIB = 'zlib'
    BZ2 = 'bz2'
    algorithm_to_import = {
        ZLIB: zlib,
        BZ2: bz2,
    }

    def __init__(self, compression_level=6, algorithm=ZLIB, *args,
                 **kwargs):
        self.compression_level = compression_level
        if algorithm not in self.algorithm_to_import:
            raise ValueError('Unrecognized algorithm %s' % algorithm)
        compress_module = self.algorithm_to_import[algorithm]
        if compress_module is None:
            raise ValueError('Missing library required for %s.' % algorithm)

        self.algorithm = algorithm
        self.compress = compress_module.compress
        self.decompress = compress_module.decompress
        super(CompressedField, self).__init__(*args, **kwargs)

    def python_value(self, value):
        if value is not None:
            return self.decompress(value)

    def db_value(self, value):
        if value is not None:
            return self._constructor(
                self.compress(value, self.compression_level))


class PickleField(BlobField):
    def python_value(self, value):
        if value is not None:
            if isinstance(value, buffer_type):
                value = bytes(value)
            return pickle.loads(value)

    def db_value(self, value):
        if value is not None:
            pickled = pickle.dumps(value, pickle.HIGHEST_PROTOCOL)
            return self._constructor(pickled)
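A sketch of how these two field types are declared on a model; the Document model and its SQLite database are assumptions made for illustration:

from peewee import Model, SqliteDatabase, TextField
from playhouse.fields import CompressedField, PickleField

db = SqliteDatabase(':memory:')


class Document(Model):
    title = TextField()
    body = CompressedField(compression_level=9)   # zlib by default
    metadata = PickleField(null=True)             # arbitrary pickled Python objects

    class Meta:
        database = db


db.create_tables([Document])
Document.create(title='notes', body=b'long text ' * 1000,
                metadata={'tags': ['a', 'b']})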
@@ -1,185 +0,0 @@
|
|||
import math
|
||||
import sys
|
||||
|
||||
from flask import abort
|
||||
from flask import render_template
|
||||
from flask import request
|
||||
from peewee import Database
|
||||
from peewee import DoesNotExist
|
||||
from peewee import Model
|
||||
from peewee import Proxy
|
||||
from peewee import SelectQuery
|
||||
from playhouse.db_url import connect as db_url_connect
|
||||
|
||||
|
||||
class PaginatedQuery(object):
|
||||
def __init__(self, query_or_model, paginate_by, page_var='page', page=None,
|
||||
check_bounds=False):
|
||||
self.paginate_by = paginate_by
|
||||
self.page_var = page_var
|
||||
self.page = page or None
|
||||
self.check_bounds = check_bounds
|
||||
|
||||
if isinstance(query_or_model, SelectQuery):
|
||||
self.query = query_or_model
|
||||
self.model = self.query.model
|
||||
else:
|
||||
self.model = query_or_model
|
||||
self.query = self.model.select()
|
||||
|
||||
def get_page(self):
|
||||
if self.page is not None:
|
||||
return self.page
|
||||
|
||||
curr_page = request.args.get(self.page_var)
|
||||
if curr_page and curr_page.isdigit():
|
||||
return max(1, int(curr_page))
|
||||
return 1
|
||||
|
||||
def get_page_count(self):
|
||||
if not hasattr(self, '_page_count'):
|
||||
self._page_count = int(math.ceil(
|
||||
float(self.query.count()) / self.paginate_by))
|
||||
return self._page_count
|
||||
|
||||
def get_object_list(self):
|
||||
if self.check_bounds and self.get_page() > self.get_page_count():
|
||||
abort(404)
|
||||
return self.query.paginate(self.get_page(), self.paginate_by)
|
||||
|
||||
|
||||
def get_object_or_404(query_or_model, *query):
|
||||
if not isinstance(query_or_model, SelectQuery):
|
||||
query_or_model = query_or_model.select()
|
||||
try:
|
||||
return query_or_model.where(*query).get()
|
||||
except DoesNotExist:
|
||||
abort(404)
|
||||
|
||||
def object_list(template_name, query, context_variable='object_list',
|
||||
paginate_by=20, page_var='page', page=None, check_bounds=True,
|
||||
**kwargs):
|
||||
paginated_query = PaginatedQuery(
|
||||
query,
|
||||
paginate_by=paginate_by,
|
||||
page_var=page_var,
|
||||
page=page,
|
||||
check_bounds=check_bounds)
|
||||
kwargs[context_variable] = paginated_query.get_object_list()
|
||||
return render_template(
|
||||
template_name,
|
||||
pagination=paginated_query,
|
||||
page=paginated_query.get_page(),
|
||||
**kwargs)
|
||||
|
||||
def get_current_url():
|
||||
if not request.query_string:
|
||||
return request.path
|
||||
return '%s?%s' % (request.path, request.query_string)
|
||||
|
||||
def get_next_url(default='/'):
|
||||
if request.args.get('next'):
|
||||
return request.args['next']
|
||||
elif request.form.get('next'):
|
||||
return request.form['next']
|
||||
return default
|
||||
|
||||
class FlaskDB(object):
|
||||
def __init__(self, app=None, database=None, model_class=Model):
|
||||
self.database = None # Reference to actual Peewee database instance.
|
||||
self.base_model_class = model_class
|
||||
self._app = app
|
||||
self._db = database # dict, url, Database, or None (default).
|
||||
if app is not None:
|
||||
self.init_app(app)
|
||||
|
||||
def init_app(self, app):
|
||||
self._app = app
|
||||
|
||||
if self._db is None:
|
||||
if 'DATABASE' in app.config:
|
||||
initial_db = app.config['DATABASE']
|
||||
elif 'DATABASE_URL' in app.config:
|
||||
initial_db = app.config['DATABASE_URL']
|
||||
else:
|
||||
raise ValueError('Missing required configuration data for '
|
||||
'database: DATABASE or DATABASE_URL.')
|
||||
else:
|
||||
initial_db = self._db
|
||||
|
||||
self._load_database(app, initial_db)
|
||||
self._register_handlers(app)
|
||||
|
||||
def _load_database(self, app, config_value):
|
||||
if isinstance(config_value, Database):
|
||||
database = config_value
|
||||
elif isinstance(config_value, dict):
|
||||
database = self._load_from_config_dict(dict(config_value))
|
||||
else:
|
||||
# Assume a database connection URL.
|
||||
database = db_url_connect(config_value)
|
||||
|
||||
if isinstance(self.database, Proxy):
|
||||
self.database.initialize(database)
|
||||
else:
|
||||
self.database = database
|
||||
|
||||
def _load_from_config_dict(self, config_dict):
|
||||
try:
|
||||
name = config_dict.pop('name')
|
||||
engine = config_dict.pop('engine')
|
||||
except KeyError:
|
||||
raise RuntimeError('DATABASE configuration must specify a '
|
||||
'`name` and `engine`.')
|
||||
|
||||
if '.' in engine:
|
||||
path, class_name = engine.rsplit('.', 1)
|
||||
else:
|
||||
path, class_name = 'peewee', engine
|
||||
|
||||
try:
|
||||
__import__(path)
|
||||
module = sys.modules[path]
|
||||
database_class = getattr(module, class_name)
|
||||
assert issubclass(database_class, Database)
|
||||
except ImportError:
|
||||
raise RuntimeError('Unable to import %s' % engine)
|
||||
except AttributeError:
|
||||
raise RuntimeError('Database engine not found %s' % engine)
|
||||
except AssertionError:
|
||||
raise RuntimeError('Database engine not a subclass of '
|
||||
'peewee.Database: %s' % engine)
|
||||
|
||||
return database_class(name, **config_dict)
|
||||
|
||||
def _register_handlers(self, app):
|
||||
app.before_request(self.connect_db)
|
||||
app.teardown_request(self.close_db)
|
||||
|
||||
def get_model_class(self):
|
||||
if self.database is None:
|
||||
raise RuntimeError('Database must be initialized.')
|
||||
|
||||
class BaseModel(self.base_model_class):
|
||||
class Meta:
|
||||
database = self.database
|
||||
|
||||
return BaseModel
|
||||
|
||||
@property
|
||||
def Model(self):
|
||||
if self._app is None:
|
||||
database = getattr(self, 'database', None)
|
||||
if database is None:
|
||||
self.database = Proxy()
|
||||
|
||||
if not hasattr(self, '_model_class'):
|
||||
self._model_class = self.get_model_class()
|
||||
return self._model_class
|
||||
|
||||
def connect_db(self):
|
||||
self.database.connect()
|
||||
|
||||
def close_db(self, exc):
|
||||
if not self.database.is_closed():
|
||||
self.database.close()
|
|
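A minimal sketch wiring FlaskDB and object_list into a Flask app; the app, the notes.html template and the Note model are assumptions for illustration:

from flask import Flask
from peewee import TextField
from playhouse.flask_utils import FlaskDB, object_list

app = Flask(__name__)
app.config['DATABASE'] = 'sqlite:///notes.db'   # parsed via playhouse.db_url

db_wrapper = FlaskDB(app)                       # registers connect/close per request


class Note(db_wrapper.Model):                   # base model bound to the database
    content = TextField()


@app.route('/notes/')
def note_list():
    # PaginatedQuery drives the pagination; the query lands in 'object_list'.
    return object_list('notes.html', Note.select(), paginate_by=10)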
@@ -1,53 +0,0 @@
from peewee import ModelDescriptor


# Hybrid methods/attributes, based on similar functionality in SQLAlchemy:
# http://docs.sqlalchemy.org/en/improve_toc/orm/extensions/hybrid.html
class hybrid_method(ModelDescriptor):
    def __init__(self, func, expr=None):
        self.func = func
        self.expr = expr or func

    def __get__(self, instance, instance_type):
        if instance is None:
            return self.expr.__get__(instance_type, instance_type.__class__)
        return self.func.__get__(instance, instance_type)

    def expression(self, expr):
        self.expr = expr
        return self


class hybrid_property(ModelDescriptor):
    def __init__(self, fget, fset=None, fdel=None, expr=None):
        self.fget = fget
        self.fset = fset
        self.fdel = fdel
        self.expr = expr or fget

    def __get__(self, instance, instance_type):
        if instance is None:
            return self.expr(instance_type)
        return self.fget(instance)

    def __set__(self, instance, value):
        if self.fset is None:
            raise AttributeError('Cannot set attribute.')
        self.fset(instance, value)

    def __delete__(self, instance):
        if self.fdel is None:
            raise AttributeError('Cannot delete attribute.')
        self.fdel(instance)

    def setter(self, fset):
        self.fset = fset
        return self

    def deleter(self, fdel):
        self.fdel = fdel
        return self

    def expression(self, expr):
        self.expr = expr
        return self
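A short sketch of a hybrid_property that works both on instances and inside queries; the Interval model is hypothetical:

from peewee import IntegerField, Model, SqliteDatabase
from playhouse.hybrid import hybrid_property

db = SqliteDatabase(':memory:')


class Interval(Model):
    start = IntegerField()
    end = IntegerField()

    class Meta:
        database = db

    @hybrid_property
    def length(self):
        # Plain Python on an instance, a SQL expression on the class.
        return self.end - self.start


db.create_tables([Interval])
Interval.create(start=1, end=10)
print(Interval.get(Interval.start == 1).length)                # 9
print(Interval.select().where(Interval.length > 5).count())    # 1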
@@ -1,172 +0,0 @@
|
|||
import operator
|
||||
|
||||
from peewee import *
|
||||
from peewee import Expression
|
||||
from playhouse.fields import PickleField
|
||||
try:
|
||||
from playhouse.sqlite_ext import CSqliteExtDatabase as SqliteExtDatabase
|
||||
except ImportError:
|
||||
from playhouse.sqlite_ext import SqliteExtDatabase
|
||||
|
||||
|
||||
Sentinel = type('Sentinel', (object,), {})
|
||||
|
||||
|
||||
class KeyValue(object):
|
||||
"""
|
||||
Persistent dictionary.
|
||||
|
||||
:param Field key_field: field to use for key. Defaults to CharField.
|
||||
:param Field value_field: field to use for value. Defaults to PickleField.
|
||||
:param bool ordered: data should be returned in key-sorted order.
|
||||
:param Database database: database where key/value data is stored.
|
||||
:param str table_name: table name for data.
|
||||
"""
|
||||
def __init__(self, key_field=None, value_field=None, ordered=False,
|
||||
database=None, table_name='keyvalue'):
|
||||
if key_field is None:
|
||||
key_field = CharField(max_length=255, primary_key=True)
|
||||
if not key_field.primary_key:
|
||||
raise ValueError('key_field must have primary_key=True.')
|
||||
|
||||
if value_field is None:
|
||||
value_field = PickleField()
|
||||
|
||||
self._key_field = key_field
|
||||
self._value_field = value_field
|
||||
self._ordered = ordered
|
||||
self._database = database or SqliteExtDatabase(':memory:')
|
||||
self._table_name = table_name
|
||||
if isinstance(self._database, PostgresqlDatabase):
|
||||
self.upsert = self._postgres_upsert
|
||||
self.update = self._postgres_update
|
||||
else:
|
||||
self.upsert = self._upsert
|
||||
self.update = self._update
|
||||
|
||||
self.model = self.create_model()
|
||||
self.key = self.model.key
|
||||
self.value = self.model.value
|
||||
|
||||
# Ensure table exists.
|
||||
self.model.create_table()
|
||||
|
||||
def create_model(self):
|
||||
class KeyValue(Model):
|
||||
key = self._key_field
|
||||
value = self._value_field
|
||||
class Meta:
|
||||
database = self._database
|
||||
table_name = self._table_name
|
||||
return KeyValue
|
||||
|
||||
def query(self, *select):
|
||||
query = self.model.select(*select).tuples()
|
||||
if self._ordered:
|
||||
query = query.order_by(self.key)
|
||||
return query
|
||||
|
||||
def convert_expression(self, expr):
|
||||
if not isinstance(expr, Expression):
|
||||
return (self.key == expr), True
|
||||
return expr, False
|
||||
|
||||
def __contains__(self, key):
|
||||
expr, _ = self.convert_expression(key)
|
||||
return self.model.select().where(expr).exists()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.model)
|
||||
|
||||
def __getitem__(self, expr):
|
||||
converted, is_single = self.convert_expression(expr)
|
||||
query = self.query(self.value).where(converted)
|
||||
item_getter = operator.itemgetter(0)
|
||||
result = [item_getter(row) for row in query]
|
||||
if len(result) == 0 and is_single:
|
||||
raise KeyError(expr)
|
||||
elif is_single:
|
||||
return result[0]
|
||||
return result
|
||||
|
||||
def _upsert(self, key, value):
|
||||
(self.model
|
||||
.insert(key=key, value=value)
|
||||
.on_conflict('replace')
|
||||
.execute())
|
||||
|
||||
def _postgres_upsert(self, key, value):
|
||||
(self.model
|
||||
.insert(key=key, value=value)
|
||||
.on_conflict(conflict_target=[self.key],
|
||||
preserve=[self.value])
|
||||
.execute())
|
||||
|
||||
def __setitem__(self, expr, value):
|
||||
if isinstance(expr, Expression):
|
||||
self.model.update(value=value).where(expr).execute()
|
||||
else:
|
||||
self.upsert(expr, value)
|
||||
|
||||
def __delitem__(self, expr):
|
||||
converted, _ = self.convert_expression(expr)
|
||||
self.model.delete().where(converted).execute()
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.query().execute())
|
||||
|
||||
def keys(self):
|
||||
return map(operator.itemgetter(0), self.query(self.key))
|
||||
|
||||
def values(self):
|
||||
return map(operator.itemgetter(0), self.query(self.value))
|
||||
|
||||
def items(self):
|
||||
return iter(self.query().execute())
|
||||
|
||||
def _update(self, __data=None, **mapping):
|
||||
if __data is not None:
|
||||
mapping.update(__data)
|
||||
return (self.model
|
||||
.insert_many(list(mapping.items()),
|
||||
fields=[self.key, self.value])
|
||||
.on_conflict('replace')
|
||||
.execute())
|
||||
|
||||
def _postgres_update(self, __data=None, **mapping):
|
||||
if __data is not None:
|
||||
mapping.update(__data)
|
||||
return (self.model
|
||||
.insert_many(list(mapping.items()),
|
||||
fields=[self.key, self.value])
|
||||
.on_conflict(conflict_target=[self.key],
|
||||
preserve=[self.value])
|
||||
.execute())
|
||||
|
||||
def get(self, key, default=None):
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
def setdefault(self, key, default=None):
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
self[key] = default
|
||||
return default
|
||||
|
||||
def pop(self, key, default=Sentinel):
|
||||
with self._database.atomic():
|
||||
try:
|
||||
result = self[key]
|
||||
except KeyError:
|
||||
if default is Sentinel:
|
||||
raise
|
||||
return default
|
||||
del self[key]
|
||||
|
||||
return result
|
||||
|
||||
def clear(self):
|
||||
self.model.delete().execute()
|
|
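A usage sketch of the KeyValue store documented above, relying on its default in-memory SQLite database:

from playhouse.kv import KeyValue

kv = KeyValue()                        # CharField key, PickleField value
kv['last_run'] = {'status': 'ok', 'count': 3}
kv.update(a=1, b=2)                    # upsert several keys at once

print(kv['last_run'])                  # {'status': 'ok', 'count': 3}
print('a' in kv, kv.get('missing'))    # True None
print(kv.pop('b'))                     # 2, and the row is removed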
@@ -1,823 +0,0 @@
|
|||
"""
|
||||
Lightweight schema migrations.
|
||||
|
||||
NOTE: Currently tested with SQLite and Postgresql. MySQL may be missing some
|
||||
features.
|
||||
|
||||
Example Usage
|
||||
-------------
|
||||
|
||||
Instantiate a migrator:
|
||||
|
||||
# Postgres example:
|
||||
my_db = PostgresqlDatabase(...)
|
||||
migrator = PostgresqlMigrator(my_db)
|
||||
|
||||
# SQLite example:
|
||||
my_db = SqliteDatabase('my_database.db')
|
||||
migrator = SqliteMigrator(my_db)
|
||||
|
||||
Then you will use the `migrate` function to run various `Operation`s which
|
||||
are generated by the migrator:
|
||||
|
||||
migrate(
|
||||
migrator.add_column('some_table', 'column_name', CharField(default=''))
|
||||
)
|
||||
|
||||
Migrations are not run inside a transaction, so if you wish the migration to
|
||||
run in a transaction you will need to wrap the call to `migrate` in a
|
||||
transaction block, e.g.:
|
||||
|
||||
with my_db.transaction():
|
||||
migrate(...)
|
||||
|
||||
Supported Operations
|
||||
--------------------
|
||||
|
||||
Add new field(s) to an existing model:
|
||||
|
||||
# Create your field instances. For non-null fields you must specify a
|
||||
# default value.
|
||||
pubdate_field = DateTimeField(null=True)
|
||||
comment_field = TextField(default='')
|
||||
|
||||
# Run the migration, specifying the database table, field name and field.
|
||||
migrate(
|
||||
migrator.add_column('comment_tbl', 'pub_date', pubdate_field),
|
||||
migrator.add_column('comment_tbl', 'comment', comment_field),
|
||||
)
|
||||
|
||||
Renaming a field:
|
||||
|
||||
# Specify the table, original name of the column, and its new name.
|
||||
migrate(
|
||||
migrator.rename_column('story', 'pub_date', 'publish_date'),
|
||||
migrator.rename_column('story', 'mod_date', 'modified_date'),
|
||||
)
|
||||
|
||||
Dropping a field:
|
||||
|
||||
migrate(
|
||||
migrator.drop_column('story', 'some_old_field'),
|
||||
)
|
||||
|
||||
Making a field nullable or not nullable:
|
||||
|
||||
# Note that when making a field not null that field must not have any
|
||||
# NULL values present.
|
||||
migrate(
|
||||
# Make `pub_date` allow NULL values.
|
||||
migrator.drop_not_null('story', 'pub_date'),
|
||||
|
||||
# Prevent `modified_date` from containing NULL values.
|
||||
migrator.add_not_null('story', 'modified_date'),
|
||||
)
|
||||
|
||||
Renaming a table:
|
||||
|
||||
migrate(
|
||||
migrator.rename_table('story', 'stories_tbl'),
|
||||
)
|
||||
|
||||
Adding an index:
|
||||
|
||||
# Specify the table, column names, and whether the index should be
|
||||
# UNIQUE or not.
|
||||
migrate(
|
||||
# Create an index on the `pub_date` column.
|
||||
migrator.add_index('story', ('pub_date',), False),
|
||||
|
||||
# Create a multi-column index on the `pub_date` and `status` fields.
|
||||
migrator.add_index('story', ('pub_date', 'status'), False),
|
||||
|
||||
# Create a unique index on the category and title fields.
|
||||
migrator.add_index('story', ('category_id', 'title'), True),
|
||||
)
|
||||
|
||||
Dropping an index:
|
||||
|
||||
# Specify the index name.
|
||||
migrate(migrator.drop_index('story', 'story_pub_date_status'))
|
||||
|
||||
Adding or dropping table constraints:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Add a CHECK() constraint to enforce the price cannot be negative.
|
||||
migrate(migrator.add_constraint(
|
||||
'products',
|
||||
'price_check',
|
||||
Check('price >= 0')))
|
||||
|
||||
# Remove the price check constraint.
|
||||
migrate(migrator.drop_constraint('products', 'price_check'))
|
||||
|
||||
# Add a UNIQUE constraint on the first and last names.
|
||||
migrate(migrator.add_unique('person', 'first_name', 'last_name'))
|
||||
"""
|
||||
from collections import namedtuple
|
||||
import functools
|
||||
import hashlib
|
||||
import re
|
||||
|
||||
from peewee import *
|
||||
from peewee import CommaNodeList
|
||||
from peewee import EnclosedNodeList
|
||||
from peewee import Entity
|
||||
from peewee import Expression
|
||||
from peewee import Node
|
||||
from peewee import NodeList
|
||||
from peewee import OP
|
||||
from peewee import callable_
|
||||
from peewee import sort_models
|
||||
from peewee import _truncate_constraint_name
|
||||
|
||||
|
||||
class Operation(object):
|
||||
"""Encapsulate a single schema altering operation."""
|
||||
def __init__(self, migrator, method, *args, **kwargs):
|
||||
self.migrator = migrator
|
||||
self.method = method
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
def execute(self, node):
|
||||
self.migrator.database.execute(node)
|
||||
|
||||
def _handle_result(self, result):
|
||||
if isinstance(result, (Node, Context)):
|
||||
self.execute(result)
|
||||
elif isinstance(result, Operation):
|
||||
result.run()
|
||||
elif isinstance(result, (list, tuple)):
|
||||
for item in result:
|
||||
self._handle_result(item)
|
||||
|
||||
def run(self):
|
||||
kwargs = self.kwargs.copy()
|
||||
kwargs['with_context'] = True
|
||||
method = getattr(self.migrator, self.method)
|
||||
self._handle_result(method(*self.args, **kwargs))
|
||||
|
||||
|
||||
def operation(fn):
|
||||
@functools.wraps(fn)
|
||||
def inner(self, *args, **kwargs):
|
||||
with_context = kwargs.pop('with_context', False)
|
||||
if with_context:
|
||||
return fn(self, *args, **kwargs)
|
||||
return Operation(self, fn.__name__, *args, **kwargs)
|
||||
return inner
|
||||
|
||||
|
||||
def make_index_name(table_name, columns):
|
||||
index_name = '_'.join((table_name,) + tuple(columns))
|
||||
if len(index_name) > 64:
|
||||
index_hash = hashlib.md5(index_name.encode('utf-8')).hexdigest()
|
||||
index_name = '%s_%s' % (index_name[:56], index_hash[:7])
|
||||
return index_name
|
||||
|
||||
|
||||
class SchemaMigrator(object):
|
||||
explicit_create_foreign_key = False
|
||||
explicit_delete_foreign_key = False
|
||||
|
||||
def __init__(self, database):
|
||||
self.database = database
|
||||
|
||||
def make_context(self):
|
||||
return self.database.get_sql_context()
|
||||
|
||||
@classmethod
|
||||
def from_database(cls, database):
|
||||
if isinstance(database, PostgresqlDatabase):
|
||||
return PostgresqlMigrator(database)
|
||||
elif isinstance(database, MySQLDatabase):
|
||||
return MySQLMigrator(database)
|
||||
elif isinstance(database, SqliteDatabase):
|
||||
return SqliteMigrator(database)
|
||||
raise ValueError('Unsupported database: %s' % database)
|
||||
|
||||
@operation
|
||||
def apply_default(self, table, column_name, field):
|
||||
default = field.default
|
||||
if callable_(default):
|
||||
default = default()
|
||||
|
||||
return (self.make_context()
|
||||
.literal('UPDATE ')
|
||||
.sql(Entity(table))
|
||||
.literal(' SET ')
|
||||
.sql(Expression(
|
||||
Entity(column_name),
|
||||
OP.EQ,
|
||||
field.db_value(default),
|
||||
flat=True)))
|
||||
|
||||
def _alter_table(self, ctx, table):
|
||||
return ctx.literal('ALTER TABLE ').sql(Entity(table))
|
||||
|
||||
def _alter_column(self, ctx, table, column):
|
||||
return (self
|
||||
._alter_table(ctx, table)
|
||||
.literal(' ALTER COLUMN ')
|
||||
.sql(Entity(column)))
|
||||
|
||||
@operation
|
||||
def alter_add_column(self, table, column_name, field):
|
||||
# Make field null at first.
|
||||
ctx = self.make_context()
|
||||
field_null, field.null = field.null, True
|
||||
field.name = field.column_name = column_name
|
||||
(self
|
||||
._alter_table(ctx, table)
|
||||
.literal(' ADD COLUMN ')
|
||||
.sql(field.ddl(ctx)))
|
||||
|
||||
field.null = field_null
|
||||
if isinstance(field, ForeignKeyField):
|
||||
self.add_inline_fk_sql(ctx, field)
|
||||
return ctx
|
||||
|
||||
@operation
|
||||
def add_constraint(self, table, name, constraint):
|
||||
return (self
|
||||
._alter_table(self.make_context(), table)
|
||||
.literal(' ADD CONSTRAINT ')
|
||||
.sql(Entity(name))
|
||||
.literal(' ')
|
||||
.sql(constraint))
|
||||
|
||||
@operation
|
||||
def add_unique(self, table, *column_names):
|
||||
constraint_name = 'uniq_%s' % '_'.join(column_names)
|
||||
constraint = NodeList((
|
||||
SQL('UNIQUE'),
|
||||
EnclosedNodeList([Entity(column) for column in column_names])))
|
||||
return self.add_constraint(table, constraint_name, constraint)
|
||||
|
||||
@operation
|
||||
def drop_constraint(self, table, name):
|
||||
return (self
|
||||
._alter_table(self.make_context(), table)
|
||||
.literal(' DROP CONSTRAINT ')
|
||||
.sql(Entity(name)))
|
||||
|
||||
def add_inline_fk_sql(self, ctx, field):
|
||||
ctx = (ctx
|
||||
.literal(' REFERENCES ')
|
||||
.sql(Entity(field.rel_model._meta.table_name))
|
||||
.literal(' ')
|
||||
.sql(EnclosedNodeList((Entity(field.rel_field.column_name),))))
|
||||
if field.on_delete is not None:
|
||||
ctx = ctx.literal(' ON DELETE %s' % field.on_delete)
|
||||
if field.on_update is not None:
|
||||
ctx = ctx.literal(' ON UPDATE %s' % field.on_update)
|
||||
return ctx
|
||||
|
||||
@operation
|
||||
def add_foreign_key_constraint(self, table, column_name, rel, rel_column,
|
||||
on_delete=None, on_update=None):
|
||||
constraint = 'fk_%s_%s_refs_%s' % (table, column_name, rel)
|
||||
ctx = (self
|
||||
.make_context()
|
||||
.literal('ALTER TABLE ')
|
||||
.sql(Entity(table))
|
||||
.literal(' ADD CONSTRAINT ')
|
||||
.sql(Entity(_truncate_constraint_name(constraint)))
|
||||
.literal(' FOREIGN KEY ')
|
||||
.sql(EnclosedNodeList((Entity(column_name),)))
|
||||
.literal(' REFERENCES ')
|
||||
.sql(Entity(rel))
|
||||
.literal(' (')
|
||||
.sql(Entity(rel_column))
|
||||
.literal(')'))
|
||||
if on_delete is not None:
|
||||
ctx = ctx.literal(' ON DELETE %s' % on_delete)
|
||||
if on_update is not None:
|
||||
ctx = ctx.literal(' ON UPDATE %s' % on_update)
|
||||
return ctx
|
||||
|
||||
@operation
|
||||
def add_column(self, table, column_name, field):
|
||||
# Adding a column is complicated by the fact that if there are rows
|
||||
# present and the field is non-null, then we need to first add the
|
||||
# column as a nullable field, then set the value, then add a not null
|
||||
# constraint.
|
||||
if not field.null and field.default is None:
|
||||
raise ValueError('%s is not null but has no default' % column_name)
|
||||
|
||||
is_foreign_key = isinstance(field, ForeignKeyField)
|
||||
if is_foreign_key and not field.rel_field:
|
||||
raise ValueError('Foreign keys must specify a `field`.')
|
||||
|
||||
operations = [self.alter_add_column(table, column_name, field)]
|
||||
|
||||
# In the event the field is *not* nullable, update with the default
|
||||
# value and set not null.
|
||||
if not field.null:
|
||||
operations.extend([
|
||||
self.apply_default(table, column_name, field),
|
||||
self.add_not_null(table, column_name)])
|
||||
|
||||
if is_foreign_key and self.explicit_create_foreign_key:
|
||||
operations.append(
|
||||
self.add_foreign_key_constraint(
|
||||
table,
|
||||
column_name,
|
||||
field.rel_model._meta.table_name,
|
||||
field.rel_field.column_name,
|
||||
field.on_delete,
|
||||
field.on_update))
|
||||
|
||||
if field.index or field.unique:
|
||||
using = getattr(field, 'index_type', None)
|
||||
operations.append(self.add_index(table, (column_name,),
|
||||
field.unique, using))
|
||||
|
||||
return operations
|
||||
|
||||
@operation
|
||||
def drop_foreign_key_constraint(self, table, column_name):
|
||||
raise NotImplementedError
|
||||
|
||||
@operation
|
||||
def drop_column(self, table, column_name, cascade=True):
|
||||
ctx = self.make_context()
|
||||
(self._alter_table(ctx, table)
|
||||
.literal(' DROP COLUMN ')
|
||||
.sql(Entity(column_name)))
|
||||
|
||||
if cascade:
|
||||
ctx.literal(' CASCADE')
|
||||
|
||||
fk_columns = [
|
||||
foreign_key.column
|
||||
for foreign_key in self.database.get_foreign_keys(table)]
|
||||
if column_name in fk_columns and self.explicit_delete_foreign_key:
|
||||
return [self.drop_foreign_key_constraint(table, column_name), ctx]
|
||||
|
||||
return ctx
|
||||
|
||||
@operation
|
||||
def rename_column(self, table, old_name, new_name):
|
||||
return (self
|
||||
._alter_table(self.make_context(), table)
|
||||
.literal(' RENAME COLUMN ')
|
||||
.sql(Entity(old_name))
|
||||
.literal(' TO ')
|
||||
.sql(Entity(new_name)))
|
||||
|
||||
@operation
|
||||
def add_not_null(self, table, column):
|
||||
return (self
|
||||
._alter_column(self.make_context(), table, column)
|
||||
.literal(' SET NOT NULL'))
|
||||
|
||||
@operation
|
||||
def drop_not_null(self, table, column):
|
||||
return (self
|
||||
._alter_column(self.make_context(), table, column)
|
||||
.literal(' DROP NOT NULL'))
|
||||
|
||||
@operation
|
||||
def rename_table(self, old_name, new_name):
|
||||
return (self
|
||||
._alter_table(self.make_context(), old_name)
|
||||
.literal(' RENAME TO ')
|
||||
.sql(Entity(new_name)))
|
||||
|
||||
@operation
|
||||
def add_index(self, table, columns, unique=False, using=None):
|
||||
ctx = self.make_context()
|
||||
index_name = make_index_name(table, columns)
|
||||
table_obj = Table(table)
|
||||
cols = [getattr(table_obj.c, column) for column in columns]
|
||||
index = Index(index_name, table_obj, cols, unique=unique, using=using)
|
||||
return ctx.sql(index)
|
||||
|
||||
@operation
|
||||
def drop_index(self, table, index_name):
|
||||
return (self
|
||||
.make_context()
|
||||
.literal('DROP INDEX ')
|
||||
.sql(Entity(index_name)))
|
||||
|
||||
|
||||
class PostgresqlMigrator(SchemaMigrator):
|
||||
def _primary_key_columns(self, tbl):
|
||||
query = """
|
||||
SELECT pg_attribute.attname
|
||||
FROM pg_index, pg_class, pg_attribute
|
||||
WHERE
|
||||
pg_class.oid = '%s'::regclass AND
|
||||
indrelid = pg_class.oid AND
|
||||
pg_attribute.attrelid = pg_class.oid AND
|
||||
pg_attribute.attnum = any(pg_index.indkey) AND
|
||||
indisprimary;
|
||||
"""
|
||||
cursor = self.database.execute_sql(query % tbl)
|
||||
return [row[0] for row in cursor.fetchall()]
|
||||
|
||||
@operation
|
||||
def set_search_path(self, schema_name):
|
||||
return (self
|
||||
.make_context()
|
||||
.literal('SET search_path TO %s' % schema_name))
|
||||
|
||||
@operation
|
||||
def rename_table(self, old_name, new_name):
|
||||
pk_names = self._primary_key_columns(old_name)
|
||||
ParentClass = super(PostgresqlMigrator, self)
|
||||
|
||||
operations = [
|
||||
ParentClass.rename_table(old_name, new_name, with_context=True)]
|
||||
|
||||
if len(pk_names) == 1:
|
||||
# Check for existence of primary key sequence.
|
||||
seq_name = '%s_%s_seq' % (old_name, pk_names[0])
|
||||
query = """
|
||||
SELECT 1
|
||||
FROM information_schema.sequences
|
||||
WHERE LOWER(sequence_name) = LOWER(%s)
|
||||
"""
|
||||
cursor = self.database.execute_sql(query, (seq_name,))
|
||||
if bool(cursor.fetchone()):
|
||||
new_seq_name = '%s_%s_seq' % (new_name, pk_names[0])
|
||||
operations.append(ParentClass.rename_table(
|
||||
seq_name, new_seq_name))
|
||||
|
||||
return operations
|
||||
|
||||
|
||||
class MySQLColumn(namedtuple('_Column', ('name', 'definition', 'null', 'pk',
|
||||
'default', 'extra'))):
|
||||
@property
|
||||
def is_pk(self):
|
||||
return self.pk == 'PRI'
|
||||
|
||||
@property
|
||||
def is_unique(self):
|
||||
return self.pk == 'UNI'
|
||||
|
||||
@property
|
||||
def is_null(self):
|
||||
return self.null == 'YES'
|
||||
|
||||
def sql(self, column_name=None, is_null=None):
|
||||
if is_null is None:
|
||||
is_null = self.is_null
|
||||
if column_name is None:
|
||||
column_name = self.name
|
||||
parts = [
|
||||
Entity(column_name),
|
||||
SQL(self.definition)]
|
||||
if self.is_unique:
|
||||
parts.append(SQL('UNIQUE'))
|
||||
if is_null:
|
||||
parts.append(SQL('NULL'))
|
||||
else:
|
||||
parts.append(SQL('NOT NULL'))
|
||||
if self.is_pk:
|
||||
parts.append(SQL('PRIMARY KEY'))
|
||||
if self.extra:
|
||||
parts.append(SQL(self.extra))
|
||||
return NodeList(parts)
|
||||
|
||||
|
||||
class MySQLMigrator(SchemaMigrator):
|
||||
explicit_create_foreign_key = True
|
||||
explicit_delete_foreign_key = True
|
||||
|
||||
@operation
|
||||
def rename_table(self, old_name, new_name):
|
||||
return (self
|
||||
.make_context()
|
||||
.literal('RENAME TABLE ')
|
||||
.sql(Entity(old_name))
|
||||
.literal(' TO ')
|
||||
.sql(Entity(new_name)))
|
||||
|
||||
def _get_column_definition(self, table, column_name):
|
||||
cursor = self.database.execute_sql('DESCRIBE `%s`;' % table)
|
||||
rows = cursor.fetchall()
|
||||
for row in rows:
|
||||
column = MySQLColumn(*row)
|
||||
if column.name == column_name:
|
||||
return column
|
||||
return False
|
||||
|
||||
def get_foreign_key_constraint(self, table, column_name):
|
||||
cursor = self.database.execute_sql(
|
||||
('SELECT constraint_name '
|
||||
'FROM information_schema.key_column_usage WHERE '
|
||||
'table_schema = DATABASE() AND '
|
||||
'table_name = %s AND '
|
||||
'column_name = %s AND '
|
||||
'referenced_table_name IS NOT NULL AND '
|
||||
'referenced_column_name IS NOT NULL;'),
|
||||
(table, column_name))
|
||||
result = cursor.fetchone()
|
||||
if not result:
|
||||
raise AttributeError(
|
||||
'Unable to find foreign key constraint for '
|
||||
'"%s" on table "%s".' % (table, column_name))
|
||||
return result[0]
|
||||
|
||||
@operation
|
||||
def drop_foreign_key_constraint(self, table, column_name):
|
||||
fk_constraint = self.get_foreign_key_constraint(table, column_name)
|
||||
return (self
|
||||
.make_context()
|
||||
.literal('ALTER TABLE ')
|
||||
.sql(Entity(table))
|
||||
.literal(' DROP FOREIGN KEY ')
|
||||
.sql(Entity(fk_constraint)))
|
||||
|
||||
def add_inline_fk_sql(self, ctx, field):
|
||||
pass
|
||||
|
||||
@operation
|
||||
def add_not_null(self, table, column):
|
||||
column_def = self._get_column_definition(table, column)
|
||||
add_not_null = (self
|
||||
.make_context()
|
||||
.literal('ALTER TABLE ')
|
||||
.sql(Entity(table))
|
||||
.literal(' MODIFY ')
|
||||
.sql(column_def.sql(is_null=False)))
|
||||
|
||||
fk_objects = dict(
|
||||
(fk.column, fk)
|
||||
for fk in self.database.get_foreign_keys(table))
|
||||
if column not in fk_objects:
|
||||
return add_not_null
|
||||
|
||||
fk_metadata = fk_objects[column]
|
||||
return (self.drop_foreign_key_constraint(table, column),
|
||||
add_not_null,
|
||||
self.add_foreign_key_constraint(
|
||||
table,
|
||||
column,
|
||||
fk_metadata.dest_table,
|
||||
fk_metadata.dest_column))
|
||||
|
||||
@operation
|
||||
def drop_not_null(self, table, column):
|
||||
column = self._get_column_definition(table, column)
|
||||
if column.is_pk:
|
||||
raise ValueError('Primary keys can not be null')
|
||||
return (self
|
||||
.make_context()
|
||||
.literal('ALTER TABLE ')
|
||||
.sql(Entity(table))
|
||||
.literal(' MODIFY ')
|
||||
.sql(column.sql(is_null=True)))
|
||||
|
||||
@operation
|
||||
def rename_column(self, table, old_name, new_name):
|
||||
fk_objects = dict(
|
||||
(fk.column, fk)
|
||||
for fk in self.database.get_foreign_keys(table))
|
||||
is_foreign_key = old_name in fk_objects
|
||||
|
||||
column = self._get_column_definition(table, old_name)
|
||||
rename_ctx = (self
|
||||
.make_context()
|
||||
.literal('ALTER TABLE ')
|
||||
.sql(Entity(table))
|
||||
.literal(' CHANGE ')
|
||||
.sql(Entity(old_name))
|
||||
.literal(' ')
|
||||
.sql(column.sql(column_name=new_name)))
|
||||
if is_foreign_key:
|
||||
fk_metadata = fk_objects[old_name]
|
||||
return [
|
||||
self.drop_foreign_key_constraint(table, old_name),
|
||||
rename_ctx,
|
||||
self.add_foreign_key_constraint(
|
||||
table,
|
||||
new_name,
|
||||
fk_metadata.dest_table,
|
||||
fk_metadata.dest_column),
|
||||
]
|
||||
else:
|
||||
return rename_ctx
|
||||
|
||||
@operation
|
||||
def drop_index(self, table, index_name):
|
||||
return (self
|
||||
.make_context()
|
||||
.literal('DROP INDEX ')
|
||||
.sql(Entity(index_name))
|
||||
.literal(' ON ')
|
||||
.sql(Entity(table)))
|
||||
|
||||
|
||||
class SqliteMigrator(SchemaMigrator):
|
||||
"""
|
||||
SQLite supports a subset of ALTER TABLE queries, view the docs for the
|
||||
full details http://sqlite.org/lang_altertable.html
|
||||
"""
|
||||
column_re = re.compile('(.+?)\((.+)\)')
|
||||
column_split_re = re.compile(r'(?:[^,(]|\([^)]*\))+')
|
||||
column_name_re = re.compile('["`\']?([\w]+)')
|
||||
fk_re = re.compile('FOREIGN KEY\s+\("?([\w]+)"?\)\s+', re.I)
|
||||
|
||||
def _get_column_names(self, table):
|
||||
res = self.database.execute_sql('select * from "%s" limit 1' % table)
|
||||
return [item[0] for item in res.description]
|
||||
|
||||
def _get_create_table(self, table):
|
||||
res = self.database.execute_sql(
|
||||
('select name, sql from sqlite_master '
|
||||
'where type=? and LOWER(name)=?'),
|
||||
['table', table.lower()])
|
||||
return res.fetchone()
|
||||
|
||||
@operation
|
||||
def _update_column(self, table, column_to_update, fn):
|
||||
columns = set(column.name.lower()
|
||||
for column in self.database.get_columns(table))
|
||||
if column_to_update.lower() not in columns:
|
||||
raise ValueError('Column "%s" does not exist on "%s"' %
|
||||
(column_to_update, table))
|
||||
|
||||
# Get the SQL used to create the given table.
|
||||
table, create_table = self._get_create_table(table)
|
||||
|
||||
# Get the indexes and SQL to re-create indexes.
|
||||
indexes = self.database.get_indexes(table)
|
||||
|
||||
# Find any foreign keys we may need to remove.
|
||||
self.database.get_foreign_keys(table)
|
||||
|
||||
# Make sure the create_table does not contain any newlines or tabs,
|
||||
# allowing the regex to work correctly.
|
||||
create_table = re.sub(r'\s+', ' ', create_table)
|
||||
|
||||
# Parse out the `CREATE TABLE` and column list portions of the query.
|
||||
raw_create, raw_columns = self.column_re.search(create_table).groups()
|
||||
|
||||
# Clean up the individual column definitions.
|
||||
split_columns = self.column_split_re.findall(raw_columns)
|
||||
column_defs = [col.strip() for col in split_columns]
|
||||
|
||||
new_column_defs = []
|
||||
new_column_names = []
|
||||
original_column_names = []
|
||||
constraint_terms = ('foreign ', 'primary ', 'constraint ')
|
||||
|
||||
for column_def in column_defs:
|
||||
column_name, = self.column_name_re.match(column_def).groups()
|
||||
|
||||
if column_name == column_to_update:
|
||||
new_column_def = fn(column_name, column_def)
|
||||
if new_column_def:
|
||||
new_column_defs.append(new_column_def)
|
||||
original_column_names.append(column_name)
|
||||
column_name, = self.column_name_re.match(
|
||||
new_column_def).groups()
|
||||
new_column_names.append(column_name)
|
||||
else:
|
||||
new_column_defs.append(column_def)
|
||||
|
||||
# Avoid treating constraints as columns.
|
||||
if not column_def.lower().startswith(constraint_terms):
|
||||
new_column_names.append(column_name)
|
||||
original_column_names.append(column_name)
|
||||
|
||||
# Create a mapping of original columns to new columns.
|
||||
original_to_new = dict(zip(original_column_names, new_column_names))
|
||||
new_column = original_to_new.get(column_to_update)
|
||||
|
||||
fk_filter_fn = lambda column_def: column_def
|
||||
if not new_column:
|
||||
# Remove any foreign keys associated with this column.
|
||||
fk_filter_fn = lambda column_def: None
|
||||
elif new_column != column_to_update:
|
||||
# Update any foreign keys for this column.
|
||||
fk_filter_fn = lambda column_def: self.fk_re.sub(
|
||||
'FOREIGN KEY ("%s") ' % new_column,
|
||||
column_def)
|
||||
|
||||
cleaned_columns = []
|
||||
for column_def in new_column_defs:
|
||||
match = self.fk_re.match(column_def)
|
||||
if match is not None and match.groups()[0] == column_to_update:
|
||||
column_def = fk_filter_fn(column_def)
|
||||
if column_def:
|
||||
cleaned_columns.append(column_def)
|
||||
|
||||
# Update the name of the new CREATE TABLE query.
|
||||
temp_table = table + '__tmp__'
|
||||
rgx = re.compile('("?)%s("?)' % table, re.I)
|
||||
create = rgx.sub(
|
||||
'\\1%s\\2' % temp_table,
|
||||
raw_create)
|
||||
|
||||
# Create the new table.
|
||||
columns = ', '.join(cleaned_columns)
|
||||
queries = [
|
||||
NodeList([SQL('DROP TABLE IF EXISTS'), Entity(temp_table)]),
|
||||
SQL('%s (%s)' % (create.strip(), columns))]
|
||||
|
||||
# Populate new table.
|
||||
populate_table = NodeList((
|
||||
SQL('INSERT INTO'),
|
||||
Entity(temp_table),
|
||||
EnclosedNodeList([Entity(col) for col in new_column_names]),
|
||||
SQL('SELECT'),
|
||||
CommaNodeList([Entity(col) for col in original_column_names]),
|
||||
SQL('FROM'),
|
||||
Entity(table)))
|
||||
drop_original = NodeList([SQL('DROP TABLE'), Entity(table)])
|
||||
|
||||
# Drop existing table and rename temp table.
|
||||
queries += [
|
||||
populate_table,
|
||||
drop_original,
|
||||
self.rename_table(temp_table, table)]
|
||||
|
||||
# Re-create user-defined indexes. User-defined indexes will have a
|
||||
# non-empty SQL attribute.
|
||||
for index in filter(lambda idx: idx.sql, indexes):
|
||||
if column_to_update not in index.columns:
|
||||
queries.append(SQL(index.sql))
|
||||
elif new_column:
|
||||
sql = self._fix_index(index.sql, column_to_update, new_column)
|
||||
if sql is not None:
|
||||
queries.append(SQL(sql))
|
||||
|
||||
return queries
|
||||
|
||||
def _fix_index(self, sql, column_to_update, new_column):
|
||||
# Split on the name of the column to update. If it splits into two
|
||||
# pieces, then there's no ambiguity and we can simply replace the
|
||||
# old with the new.
|
||||
parts = sql.split(column_to_update)
|
||||
if len(parts) == 2:
|
||||
return sql.replace(column_to_update, new_column)
|
||||
|
||||
# Find the list of columns in the index expression.
|
||||
lhs, rhs = sql.rsplit('(', 1)
|
||||
|
||||
# Apply the same "split in two" logic to the column list portion of
|
||||
# the query.
|
||||
if len(rhs.split(column_to_update)) == 2:
|
||||
return '%s(%s' % (lhs, rhs.replace(column_to_update, new_column))
|
||||
|
||||
# Strip off the trailing parentheses and go through each column.
|
||||
parts = rhs.rsplit(')', 1)[0].split(',')
|
||||
columns = [part.strip('"`[]\' ') for part in parts]
|
||||
|
||||
# `columns` looks something like: ['status', 'timestamp" DESC']
|
||||
# https://www.sqlite.org/lang_keywords.html
|
||||
# Strip out any junk after the column name.
|
||||
clean = []
|
||||
for column in columns:
|
||||
if re.match('%s(?:[\'"`\]]?\s|$)' % column_to_update, column):
|
||||
column = new_column + column[len(column_to_update):]
|
||||
clean.append(column)
|
||||
|
||||
return '%s(%s)' % (lhs, ', '.join('"%s"' % c for c in clean))
|
||||
|
||||
@operation
|
||||
def drop_column(self, table, column_name, cascade=True):
|
||||
return self._update_column(table, column_name, lambda a, b: None)
|
||||
|
||||
@operation
|
||||
def rename_column(self, table, old_name, new_name):
|
||||
def _rename(column_name, column_def):
|
||||
return column_def.replace(column_name, new_name)
|
||||
return self._update_column(table, old_name, _rename)
|
||||
|
||||
@operation
|
||||
def add_not_null(self, table, column):
|
||||
def _add_not_null(column_name, column_def):
|
||||
return column_def + ' NOT NULL'
|
||||
return self._update_column(table, column, _add_not_null)
|
||||
|
||||
@operation
|
||||
def drop_not_null(self, table, column):
|
||||
def _drop_not_null(column_name, column_def):
|
||||
return column_def.replace('NOT NULL', '')
|
||||
return self._update_column(table, column, _drop_not_null)
|
||||
|
||||
@operation
|
||||
def add_constraint(self, table, name, constraint):
|
||||
raise NotImplementedError
|
||||
|
||||
@operation
|
||||
def drop_constraint(self, table, name):
|
||||
raise NotImplementedError
|
||||
|
||||
@operation
|
||||
def add_foreign_key_constraint(self, table, column_name, field,
|
||||
on_delete=None, on_update=None):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def migrate(*operations, **kwargs):
|
||||
for operation in operations:
|
||||
operation.run()
|
|
@@ -1,49 +0,0 @@
import json

try:
    import mysql.connector as mysql_connector
except ImportError:
    mysql_connector = None

from peewee import ImproperlyConfigured
from peewee import InterfaceError  # needed for the closed-connection error path below
from peewee import MySQLDatabase
from peewee import NodeList
from peewee import SQL
from peewee import TextField
from peewee import fn


class MySQLConnectorDatabase(MySQLDatabase):
    def _connect(self):
        if mysql_connector is None:
            raise ImproperlyConfigured('MySQL connector not installed!')
        return mysql_connector.connect(db=self.database, **self.connect_params)

    def cursor(self, commit=None):
        if self.is_closed():
            if self.autoconnect:
                self.connect()
            else:
                raise InterfaceError('Error, database connection not opened.')
        return self._state.conn.cursor(buffered=True)


class JSONField(TextField):
    field_type = 'JSON'

    def db_value(self, value):
        if value is not None:
            return json.dumps(value)

    def python_value(self, value):
        if value is not None:
            return json.loads(value)


def Match(columns, expr, modifier=None):
    if isinstance(columns, (list, tuple)):
        match = fn.MATCH(*columns)  # Tuple of one or more columns / fields.
    else:
        match = fn.MATCH(columns)   # Single column / field.
    args = expr if modifier is None else NodeList((expr, SQL(modifier)))
    return NodeList((match, fn.AGAINST(args)))
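A brief sketch of JSONField and the Match() full-text helper defined above; the Post model and the connection settings are assumptions:

from peewee import Model, TextField
from playhouse.mysql_ext import JSONField, Match, MySQLConnectorDatabase

db = MySQLConnectorDatabase('blog', user='root', password='secret',
                            host='127.0.0.1')


class Post(Model):
    content = TextField()
    attrs = JSONField(null=True)       # serialized with json.dumps/json.loads

    class Meta:
        database = db


# Renders as MATCH(content) AGAINST('peewee' IN BOOLEAN MODE).
query = Post.select().where(Match(Post.content, 'peewee', 'IN BOOLEAN MODE'))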
@@ -1,318 +0,0 @@
|
|||
"""
|
||||
Lightweight connection pooling for peewee.
|
||||
|
||||
In a multi-threaded application, up to `max_connections` will be opened. Each
|
||||
thread (or, if using gevent, greenlet) will have its own connection.
|
||||
|
||||
In a single-threaded application, only one connection will be created. It will
|
||||
be continually recycled until either it exceeds the stale timeout or is closed
|
||||
explicitly (using `.manual_close()`).
|
||||
|
||||
By default, all your application needs to do is ensure that connections are
|
||||
closed when you are finished with them, and they will be returned to the pool.
|
||||
For web applications, this typically means that at the beginning of a request,
|
||||
you will open a connection, and when you return a response, you will close the
|
||||
connection.
|
||||
|
||||
Simple Postgres pool example code:
|
||||
|
||||
# Use the special postgresql extensions.
|
||||
from playhouse.pool import PooledPostgresqlExtDatabase
|
||||
|
||||
db = PooledPostgresqlExtDatabase(
|
||||
'my_app',
|
||||
max_connections=32,
|
||||
stale_timeout=300, # 5 minutes.
|
||||
user='postgres')
|
||||
|
||||
class BaseModel(Model):
|
||||
class Meta:
|
||||
database = db
|
||||
|
||||
That's it!
|
||||
"""
|
||||
import heapq
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from collections import namedtuple
|
||||
from itertools import chain
|
||||
|
||||
try:
|
||||
from psycopg2.extensions import TRANSACTION_STATUS_IDLE
|
||||
from psycopg2.extensions import TRANSACTION_STATUS_INERROR
|
||||
from psycopg2.extensions import TRANSACTION_STATUS_UNKNOWN
|
||||
except ImportError:
|
||||
TRANSACTION_STATUS_IDLE = \
|
||||
TRANSACTION_STATUS_INERROR = \
|
||||
TRANSACTION_STATUS_UNKNOWN = None
|
||||
|
||||
from peewee import MySQLDatabase
|
||||
from peewee import PostgresqlDatabase
|
||||
from peewee import SqliteDatabase
|
||||
|
||||
logger = logging.getLogger('peewee.pool')
|
||||
|
||||
|
||||
def make_int(val):
|
||||
if val is not None and not isinstance(val, (int, float)):
|
||||
return int(val)
|
||||
return val
|
||||
|
||||
|
||||
class MaxConnectionsExceeded(ValueError): pass
|
||||
|
||||
|
||||
PoolConnection = namedtuple('PoolConnection', ('timestamp', 'connection',
|
||||
'checked_out'))
|
||||
|
||||
|
||||
class PooledDatabase(object):
|
||||
def __init__(self, database, max_connections=20, stale_timeout=None,
|
||||
timeout=None, **kwargs):
|
||||
self._max_connections = make_int(max_connections)
|
||||
self._stale_timeout = make_int(stale_timeout)
|
||||
self._wait_timeout = make_int(timeout)
|
||||
if self._wait_timeout == 0:
|
||||
self._wait_timeout = float('inf')
|
||||
|
||||
# Available / idle connections stored in a heap, sorted oldest first.
|
||||
self._connections = []
|
||||
|
||||
# Mapping of connection id to PoolConnection. Ordinarily we would want
|
||||
# to use something like a WeakKeyDictionary, but Python typically won't
|
||||
# allow us to create weak references to connection objects.
|
||||
self._in_use = {}
|
||||
|
||||
# Use the memory address of the connection as the key in the event the
|
||||
# connection object is not hashable. Connections will not get
|
||||
# garbage-collected, however, because a reference to them will persist
|
||||
# in "_in_use" as long as the conn has not been closed.
|
||||
self.conn_key = id
|
||||
|
||||
super(PooledDatabase, self).__init__(database, **kwargs)
|
||||
|
||||
def init(self, database, max_connections=None, stale_timeout=None,
|
||||
timeout=None, **connect_kwargs):
|
||||
super(PooledDatabase, self).init(database, **connect_kwargs)
|
||||
if max_connections is not None:
|
||||
self._max_connections = make_int(max_connections)
|
||||
if stale_timeout is not None:
|
||||
self._stale_timeout = make_int(stale_timeout)
|
||||
if timeout is not None:
|
||||
self._wait_timeout = make_int(timeout)
|
||||
if self._wait_timeout == 0:
|
||||
self._wait_timeout = float('inf')
|
||||
|
||||
def connect(self, reuse_if_open=False):
|
||||
if not self._wait_timeout:
|
||||
return super(PooledDatabase, self).connect(reuse_if_open)
|
||||
|
||||
expires = time.time() + self._wait_timeout
|
||||
while expires > time.time():
|
||||
try:
|
||||
ret = super(PooledDatabase, self).connect(reuse_if_open)
|
||||
except MaxConnectionsExceeded:
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
return ret
|
||||
raise MaxConnectionsExceeded('Max connections exceeded, timed out '
|
||||
'attempting to connect.')
|
||||
|
||||
def _connect(self):
|
||||
while True:
|
||||
try:
|
||||
# Remove the oldest connection from the heap.
|
||||
ts, conn = heapq.heappop(self._connections)
|
||||
key = self.conn_key(conn)
|
||||
except IndexError:
|
||||
ts = conn = None
|
||||
logger.debug('No connection available in pool.')
|
||||
break
|
||||
else:
|
||||
if self._is_closed(conn):
|
||||
# This connection was closed, but since it was not stale
|
||||
# it got added back to the queue of available conns. We
|
||||
# then closed it and marked it as explicitly closed, so
|
||||
# it's safe to throw it away now.
|
||||
# (Because Database.close() calls Database._close()).
|
||||
logger.debug('Connection %s was closed.', key)
|
||||
ts = conn = None
|
||||
elif self._stale_timeout and self._is_stale(ts):
|
||||
# If we are attempting to check out a stale connection,
|
||||
# then close it. We don't need to mark it in the "closed"
|
||||
# set, because it is not in the list of available conns
|
||||
# anymore.
|
||||
logger.debug('Connection %s was stale, closing.', key)
|
||||
self._close(conn, True)
|
||||
ts = conn = None
|
||||
else:
|
||||
break
|
||||
|
||||
if conn is None:
|
||||
if self._max_connections and (
|
||||
len(self._in_use) >= self._max_connections):
|
||||
raise MaxConnectionsExceeded('Exceeded maximum connections.')
|
||||
conn = super(PooledDatabase, self)._connect()
|
||||
ts = time.time() - random.random() / 1000  # Small jitter keeps timestamps unique so heapq never has to compare connection objects.
|
||||
key = self.conn_key(conn)
|
||||
logger.debug('Created new connection %s.', key)
|
||||
|
||||
self._in_use[key] = PoolConnection(ts, conn, time.time())
|
||||
return conn
|
||||
|
||||
def _is_stale(self, timestamp):
|
||||
# Called on check-out and check-in to ensure the connection has
|
||||
# not outlived the stale timeout.
|
||||
return (time.time() - timestamp) > self._stale_timeout
|
||||
|
||||
def _is_closed(self, conn):
|
||||
return False
|
||||
|
||||
def _can_reuse(self, conn):
|
||||
# Called on check-in to make sure the connection can be re-used.
|
||||
return True
|
||||
|
||||
def _close(self, conn, close_conn=False):
|
||||
key = self.conn_key(conn)
|
||||
if close_conn:
|
||||
super(PooledDatabase, self)._close(conn)
|
||||
elif key in self._in_use:
|
||||
pool_conn = self._in_use.pop(key)
|
||||
if self._stale_timeout and self._is_stale(pool_conn.timestamp):
|
||||
logger.debug('Closing stale connection %s.', key)
|
||||
super(PooledDatabase, self)._close(conn)
|
||||
elif self._can_reuse(conn):
|
||||
logger.debug('Returning %s to pool.', key)
|
||||
heapq.heappush(self._connections, (pool_conn.timestamp, conn))
|
||||
else:
|
||||
logger.debug('Closed %s.', key)
|
||||
|
||||
def manual_close(self):
|
||||
"""
|
||||
Close the underlying connection without returning it to the pool.
|
||||
"""
|
||||
if self.is_closed():
|
||||
return False
|
||||
|
||||
# Obtain reference to the connection in-use by the calling thread.
|
||||
conn = self.connection()
|
||||
|
||||
# A connection will only be re-added to the available list if it is
|
||||
# marked as "in use" at the time it is closed. We will explicitly
|
||||
# remove it from the "in use" list, call "close()" for the
|
||||
# side-effects, and then explicitly close the connection.
|
||||
self._in_use.pop(self.conn_key(conn), None)
|
||||
self.close()
|
||||
self._close(conn, close_conn=True)
|
||||
|
||||
def close_idle(self):
|
||||
# Close any open connections that are not currently in-use.
|
||||
with self._lock:
|
||||
for _, conn in self._connections:
|
||||
self._close(conn, close_conn=True)
|
||||
self._connections = []
|
||||
|
||||
def close_stale(self, age=600):
|
||||
# Close any connections that are in-use but were checked out quite some
|
||||
# time ago and can be considered stale.
|
||||
with self._lock:
|
||||
in_use = {}
|
||||
cutoff = time.time() - age
|
||||
n = 0
|
||||
for key, pool_conn in self._in_use.items():
|
||||
if pool_conn.checked_out < cutoff:
|
||||
self._close(pool_conn.connection, close_conn=True)
|
||||
n += 1
|
||||
else:
|
||||
in_use[key] = pool_conn
|
||||
self._in_use = in_use
|
||||
return n
|
||||
|
||||
def close_all(self):
|
||||
# Close all connections -- available and in-use. Warning: may break any
|
||||
# active connections used by other threads.
|
||||
self.close()
|
||||
with self._lock:
|
||||
for _, conn in self._connections:
|
||||
self._close(conn, close_conn=True)
|
||||
for pool_conn in self._in_use.values():
|
||||
self._close(pool_conn.connection, close_conn=True)
|
||||
self._connections = []
|
||||
self._in_use = {}
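# Usage sketch (illustrative, not part of the original module): with a pooled
# database, close() returns the connection to the pool rather than dropping
# the socket, and idle connections can be reaped explicitly. The database
# name and credentials below are placeholder assumptions.
def _pool_usage_example():
    from playhouse.pool import PooledPostgresqlDatabase

    db = PooledPostgresqlDatabase(
        'my_app',             # hypothetical database name
        max_connections=8,
        stale_timeout=300,    # recycle connections older than five minutes
        user='postgres')

    db.connect()      # check a connection out of the pool
    db.close()        # return it to the pool; the socket stays open
    db.close_idle()   # actually disconnect any connections not checked out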
|
||||
|
||||
|
||||
class PooledMySQLDatabase(PooledDatabase, MySQLDatabase):
|
||||
def _is_closed(self, conn):
|
||||
try:
|
||||
conn.ping(False)
|
||||
except Exception:  # ping() raises if the server has gone away.
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
class _PooledPostgresqlDatabase(PooledDatabase):
|
||||
def _is_closed(self, conn):
|
||||
if conn.closed:
|
||||
return True
|
||||
|
||||
txn_status = conn.get_transaction_status()
|
||||
if txn_status == TRANSACTION_STATUS_UNKNOWN:
|
||||
return True
|
||||
elif txn_status != TRANSACTION_STATUS_IDLE:
|
||||
conn.rollback()
|
||||
return False
|
||||
|
||||
def _can_reuse(self, conn):
|
||||
txn_status = conn.get_transaction_status()
|
||||
# Do not return connection in an error state, as subsequent queries
|
||||
# will all fail. If the status is unknown then we lost the connection
|
||||
# to the server and the connection should not be re-used.
|
||||
if txn_status == TRANSACTION_STATUS_UNKNOWN:
|
||||
return False
|
||||
elif txn_status == TRANSACTION_STATUS_INERROR:
|
||||
conn.reset()
|
||||
elif txn_status != TRANSACTION_STATUS_IDLE:
|
||||
conn.rollback()
|
||||
return True
|
||||
|
||||
class PooledPostgresqlDatabase(_PooledPostgresqlDatabase, PostgresqlDatabase):
|
||||
pass
|
||||
|
||||
try:
|
||||
from playhouse.postgres_ext import PostgresqlExtDatabase
|
||||
|
||||
class PooledPostgresqlExtDatabase(_PooledPostgresqlDatabase, PostgresqlExtDatabase):
|
||||
pass
|
||||
except ImportError:
|
||||
PooledPostgresqlExtDatabase = None
|
||||
|
||||
|
||||
class _PooledSqliteDatabase(PooledDatabase):
|
||||
def _is_closed(self, conn):
|
||||
try:
|
||||
conn.total_changes
|
||||
except Exception:  # sqlite3 raises ProgrammingError once the connection is closed.
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
class PooledSqliteDatabase(_PooledSqliteDatabase, SqliteDatabase):
|
||||
pass
|
||||
|
||||
try:
|
||||
from playhouse.sqlite_ext import SqliteExtDatabase
|
||||
|
||||
class PooledSqliteExtDatabase(_PooledSqliteDatabase, SqliteExtDatabase):
|
||||
pass
|
||||
except ImportError:
|
||||
PooledSqliteExtDatabase = None
|
||||
|
||||
try:
|
||||
from playhouse.sqlite_ext import CSqliteExtDatabase
|
||||
|
||||
class PooledCSqliteExtDatabase(_PooledSqliteDatabase, CSqliteExtDatabase):
|
||||
pass
|
||||
except ImportError:
|
||||
PooledCSqliteExtDatabase = None
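# Usage sketch (illustrative): the pooled SQLite variant behaves the same way;
# close_stale() reaps connections that were checked out long ago and never
# returned, which is useful in long-running multi-threaded processes. The
# file path and age are placeholder assumptions.
def _sqlite_pool_example():
    db = PooledSqliteDatabase('/tmp/app.db', max_connections=32)
    with db.connection_context():     # peewee's standard context helper
        db.execute_sql('SELECT 1')
    return db.close_stale(age=600)    # close connections held > 10 minutes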
|
|
@@ -1,474 +0,0 @@
|
|||
"""
|
||||
Collection of postgres-specific extensions, currently including:
|
||||
|
||||
* Support for hstore, a key/value type storage
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from peewee import *
|
||||
from peewee import ColumnBase
|
||||
from peewee import Expression
|
||||
from peewee import Node
|
||||
from peewee import NodeList
|
||||
from peewee import SENTINEL
|
||||
from peewee import __exception_wrapper__
|
||||
|
||||
try:
|
||||
from psycopg2cffi import compat
|
||||
compat.register()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
from psycopg2.extras import register_hstore
|
||||
try:
|
||||
from psycopg2.extras import Json
|
||||
except ImportError:
|
||||
Json = None
|
||||
|
||||
|
||||
logger = logging.getLogger('peewee')
|
||||
|
||||
|
||||
HCONTAINS_DICT = '@>'
|
||||
HCONTAINS_KEYS = '?&'
|
||||
HCONTAINS_KEY = '?'
|
||||
HCONTAINS_ANY_KEY = '?|'
|
||||
HKEY = '->'
|
||||
HUPDATE = '||'
|
||||
ACONTAINS = '@>'
|
||||
ACONTAINS_ANY = '&&'
|
||||
TS_MATCH = '@@'
|
||||
JSONB_CONTAINS = '@>'
|
||||
JSONB_CONTAINED_BY = '<@'
|
||||
JSONB_CONTAINS_KEY = '?'
|
||||
JSONB_CONTAINS_ANY_KEY = '?|'
|
||||
JSONB_CONTAINS_ALL_KEYS = '?&'
|
||||
JSONB_EXISTS = '?'
|
||||
JSONB_REMOVE = '-'
|
||||
|
||||
|
||||
class _LookupNode(ColumnBase):
|
||||
def __init__(self, node, parts):
|
||||
self.node = node
|
||||
self.parts = parts
|
||||
super(_LookupNode, self).__init__()
|
||||
|
||||
def clone(self):
|
||||
return type(self)(self.node, list(self.parts))
|
||||
|
||||
|
||||
class _JsonLookupBase(_LookupNode):
|
||||
def __init__(self, node, parts, as_json=False):
|
||||
super(_JsonLookupBase, self).__init__(node, parts)
|
||||
self._as_json = as_json
|
||||
|
||||
def clone(self):
|
||||
return type(self)(self.node, list(self.parts), self._as_json)
|
||||
|
||||
@Node.copy
|
||||
def as_json(self, as_json=True):
|
||||
self._as_json = as_json
|
||||
|
||||
def concat(self, rhs):
|
||||
return Expression(self.as_json(True), OP.CONCAT, Json(rhs))
|
||||
|
||||
def contains(self, other):
|
||||
clone = self.as_json(True)
|
||||
if isinstance(other, (list, dict)):
|
||||
return Expression(clone, JSONB_CONTAINS, Json(other))
|
||||
return Expression(clone, JSONB_EXISTS, other)
|
||||
|
||||
def contains_any(self, *keys):
|
||||
return Expression(
|
||||
self.as_json(True),
|
||||
JSONB_CONTAINS_ANY_KEY,
|
||||
Value(list(keys), unpack=False))
|
||||
|
||||
def contains_all(self, *keys):
|
||||
return Expression(
|
||||
self.as_json(True),
|
||||
JSONB_CONTAINS_ALL_KEYS,
|
||||
Value(list(keys), unpack=False))
|
||||
|
||||
def has_key(self, key):
|
||||
return Expression(self.as_json(True), JSONB_CONTAINS_KEY, key)
|
||||
|
||||
|
||||
class JsonLookup(_JsonLookupBase):
|
||||
def __getitem__(self, value):
|
||||
return JsonLookup(self.node, self.parts + [value], self._as_json)
|
||||
|
||||
def __sql__(self, ctx):
|
||||
ctx.sql(self.node)
|
||||
for part in self.parts[:-1]:
|
||||
ctx.literal('->').sql(part)
|
||||
if self.parts:
|
||||
(ctx
|
||||
.literal('->' if self._as_json else '->>')
|
||||
.sql(self.parts[-1]))
|
||||
|
||||
return ctx
|
||||
|
||||
|
||||
class JsonPath(_JsonLookupBase):
|
||||
def __sql__(self, ctx):
|
||||
return (ctx
|
||||
.sql(self.node)
|
||||
.literal('#>' if self._as_json else '#>>')
|
||||
.sql(Value('{%s}' % ','.join(map(str, self.parts)))))
|
||||
|
||||
|
||||
class ObjectSlice(_LookupNode):
|
||||
@classmethod
|
||||
def create(cls, node, value):
|
||||
if isinstance(value, slice):
|
||||
parts = [value.start or 0, value.stop or 0]
|
||||
elif isinstance(value, int):
|
||||
parts = [value]
|
||||
else:
|
||||
parts = map(int, value.split(':'))
|
||||
return cls(node, parts)
|
||||
|
||||
def __sql__(self, ctx):
|
||||
return (ctx
|
||||
.sql(self.node)
|
||||
.literal('[%s]' % ':'.join(str(p + 1) for p in self.parts)))
|
||||
|
||||
def __getitem__(self, value):
|
||||
return ObjectSlice.create(self, value)
|
||||
|
||||
|
||||
class IndexedFieldMixin(object):
|
||||
default_index_type = 'GIN'
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs.setdefault('index', True) # By default, use an index.
|
||||
super(IndexedFieldMixin, self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class ArrayField(IndexedFieldMixin, Field):
|
||||
passthrough = True
|
||||
|
||||
def __init__(self, field_class=IntegerField, field_kwargs=None,
|
||||
dimensions=1, convert_values=False, *args, **kwargs):
|
||||
self.__field = field_class(**(field_kwargs or {}))
|
||||
self.dimensions = dimensions
|
||||
self.convert_values = convert_values
|
||||
self.field_type = self.__field.field_type
|
||||
super(ArrayField, self).__init__(*args, **kwargs)
|
||||
|
||||
def bind(self, model, name, set_attribute=True):
|
||||
ret = super(ArrayField, self).bind(model, name, set_attribute)
|
||||
self.__field.bind(model, '__array_%s' % name, False)
|
||||
return ret
|
||||
|
||||
def ddl_datatype(self, ctx):
|
||||
data_type = self.__field.ddl_datatype(ctx)
|
||||
return NodeList((data_type, SQL('[]' * self.dimensions)), glue='')
|
||||
|
||||
def db_value(self, value):
|
||||
if value is None or isinstance(value, Node):
|
||||
return value
|
||||
elif self.convert_values:
|
||||
return self._process(self.__field.db_value, value, self.dimensions)
|
||||
else:
|
||||
return value if isinstance(value, list) else list(value)
|
||||
|
||||
def python_value(self, value):
|
||||
if self.convert_values and value is not None:
|
||||
conv = self.__field.python_value
|
||||
if isinstance(value, list):
|
||||
return self._process(conv, value, self.dimensions)
|
||||
else:
|
||||
return conv(value)
|
||||
else:
|
||||
return value
|
||||
|
||||
def _process(self, conv, value, dimensions):
|
||||
dimensions -= 1
|
||||
if dimensions == 0:
|
||||
return [conv(v) for v in value]
|
||||
else:
|
||||
return [self._process(conv, v, dimensions) for v in value]
|
||||
|
||||
def __getitem__(self, value):
|
||||
return ObjectSlice.create(self, value)
|
||||
|
||||
def _e(op):
|
||||
def inner(self, rhs):
|
||||
return Expression(self, op, ArrayValue(self, rhs))
|
||||
return inner
|
||||
__eq__ = _e(OP.EQ)
|
||||
__ne__ = _e(OP.NE)
|
||||
__gt__ = _e(OP.GT)
|
||||
__ge__ = _e(OP.GTE)
|
||||
__lt__ = _e(OP.LT)
|
||||
__le__ = _e(OP.LTE)
|
||||
__hash__ = Field.__hash__
|
||||
|
||||
def contains(self, *items):
|
||||
return Expression(self, ACONTAINS, ArrayValue(self, items))
|
||||
|
||||
def contains_any(self, *items):
|
||||
return Expression(self, ACONTAINS_ANY, ArrayValue(self, items))
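# Usage sketch (illustrative): declaring an ArrayField and filtering with the
# contains()/contains_any() helpers defined above. The model, database name
# and credentials are placeholder assumptions.
def _array_field_example():
    db = PostgresqlExtDatabase('my_app', user='postgres')

    class BlogPost(Model):
        title = CharField()
        tags = ArrayField(CharField)   # VARCHAR[] column, GIN-indexed by default

        class Meta:
            database = db

    # Posts tagged with *both* 'python' and 'sql' (@> operator).
    with_both = BlogPost.select().where(BlogPost.tags.contains('python', 'sql'))
    # Posts tagged with *either* of them (&& operator).
    with_any = BlogPost.select().where(BlogPost.tags.contains_any('python', 'sql'))
    return with_both, with_any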
|
||||
|
||||
|
||||
class ArrayValue(Node):
|
||||
def __init__(self, field, value):
|
||||
self.field = field
|
||||
self.value = value
|
||||
|
||||
def __sql__(self, ctx):
|
||||
return (ctx
|
||||
.sql(Value(self.value, unpack=False))
|
||||
.literal('::')
|
||||
.sql(self.field.ddl_datatype(ctx)))
|
||||
|
||||
|
||||
class DateTimeTZField(DateTimeField):
|
||||
field_type = 'TIMESTAMPTZ'
|
||||
|
||||
|
||||
class HStoreField(IndexedFieldMixin, Field):
|
||||
field_type = 'HSTORE'
|
||||
__hash__ = Field.__hash__
|
||||
|
||||
def __getitem__(self, key):
|
||||
return Expression(self, HKEY, Value(key))
|
||||
|
||||
def keys(self):
|
||||
return fn.akeys(self)
|
||||
|
||||
def values(self):
|
||||
return fn.avals(self)
|
||||
|
||||
def items(self):
|
||||
return fn.hstore_to_matrix(self)
|
||||
|
||||
def slice(self, *args):
|
||||
return fn.slice(self, Value(list(args), unpack=False))
|
||||
|
||||
def exists(self, key):
|
||||
return fn.exist(self, key)
|
||||
|
||||
def defined(self, key):
|
||||
return fn.defined(self, key)
|
||||
|
||||
def update(self, **data):
|
||||
return Expression(self, HUPDATE, data)
|
||||
|
||||
def delete(self, *keys):
|
||||
return fn.delete(self, Value(list(keys), unpack=False))
|
||||
|
||||
def contains(self, value):
|
||||
if isinstance(value, dict):
|
||||
rhs = Value(value, unpack=False)
|
||||
return Expression(self, HCONTAINS_DICT, rhs)
|
||||
elif isinstance(value, (list, tuple)):
|
||||
rhs = Value(value, unpack=False)
|
||||
return Expression(self, HCONTAINS_KEYS, rhs)
|
||||
return Expression(self, HCONTAINS_KEY, value)
|
||||
|
||||
def contains_any(self, *keys):
|
||||
return Expression(self, HCONTAINS_ANY_KEY, Value(list(keys),
|
||||
unpack=False))
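# Usage sketch (illustrative): key/value storage with the hstore helpers
# defined above. Requires the hstore extension plus a database created with
# register_hstore=True; the model and data are placeholder assumptions.
def _hstore_example():
    db = PostgresqlExtDatabase('my_app', register_hstore=True)

    class Server(Model):
        name = CharField()
        attributes = HStoreField()

        class Meta:
            database = db

    # Rows whose hstore contains the given key/value pairs (@> operator).
    linux_boxes = Server.select().where(
        Server.attributes.contains({'os': 'linux'}))
    # Pull a single key out of the hstore column (-> operator).
    kernels = Server.select(
        Server.name, Server.attributes['kernel'].alias('kernel'))
    return linux_boxes, kernels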
|
||||
|
||||
|
||||
class JSONField(Field):
|
||||
field_type = 'JSON'
|
||||
_json_datatype = 'json'
|
||||
|
||||
def __init__(self, dumps=None, *args, **kwargs):
|
||||
if Json is None:
|
||||
raise Exception('Your version of psycopg2 does not support JSON.')
|
||||
self.dumps = dumps or json.dumps
|
||||
super(JSONField, self).__init__(*args, **kwargs)
|
||||
|
||||
def db_value(self, value):
|
||||
if value is None:
|
||||
return value
|
||||
if not isinstance(value, Json):
|
||||
return Cast(self.dumps(value), self._json_datatype)
|
||||
return value
|
||||
|
||||
def __getitem__(self, value):
|
||||
return JsonLookup(self, [value])
|
||||
|
||||
def path(self, *keys):
|
||||
return JsonPath(self, keys)
|
||||
|
||||
def concat(self, value):
|
||||
return super(JSONField, self).concat(Json(value))
|
||||
|
||||
|
||||
def cast_jsonb(node):
|
||||
return NodeList((node, SQL('::jsonb')), glue='')
|
||||
|
||||
|
||||
class BinaryJSONField(IndexedFieldMixin, JSONField):
|
||||
field_type = 'JSONB'
|
||||
_json_datatype = 'jsonb'
|
||||
__hash__ = Field.__hash__
|
||||
|
||||
def contains(self, other):
|
||||
if isinstance(other, (list, dict)):
|
||||
return Expression(self, JSONB_CONTAINS, Json(other))
|
||||
return Expression(cast_jsonb(self), JSONB_EXISTS, other)
|
||||
|
||||
def contained_by(self, other):
|
||||
return Expression(cast_jsonb(self), JSONB_CONTAINED_BY, Json(other))
|
||||
|
||||
def contains_any(self, *items):
|
||||
return Expression(
|
||||
cast_jsonb(self),
|
||||
JSONB_CONTAINS_ANY_KEY,
|
||||
Value(list(items), unpack=False))
|
||||
|
||||
def contains_all(self, *items):
|
||||
return Expression(
|
||||
cast_jsonb(self),
|
||||
JSONB_CONTAINS_ALL_KEYS,
|
||||
Value(list(items), unpack=False))
|
||||
|
||||
def has_key(self, key):
|
||||
return Expression(cast_jsonb(self), JSONB_CONTAINS_KEY, key)
|
||||
|
||||
def remove(self, *items):
|
||||
return Expression(
|
||||
cast_jsonb(self),
|
||||
JSONB_REMOVE,
|
||||
Value(list(items), unpack=False))
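# Usage sketch (illustrative): jsonb containment and path lookups using the
# operators defined above; the model and payload keys are placeholder
# assumptions.
def _jsonb_example():
    db = PostgresqlExtDatabase('my_app', user='postgres')

    class Event(Model):
        payload = BinaryJSONField()

        class Meta:
            database = db

    # Rows whose payload contains this sub-document (jsonb @> operator).
    errors = Event.select().where(Event.payload.contains({'status': 'error'}))
    # Rows having either top-level key (?| operator).
    flagged = Event.select().where(Event.payload.contains_any('error', 'warning'))
    # Extract a nested value as text: payload->'request'->>'path'.
    paths = Event.select(Event.payload['request']['path'].alias('path'))
    return errors, flagged, paths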
|
||||
|
||||
|
||||
class TSVectorField(IndexedFieldMixin, TextField):
|
||||
field_type = 'TSVECTOR'
|
||||
__hash__ = Field.__hash__
|
||||
|
||||
def match(self, query, language=None, plain=False):
|
||||
params = (language, query) if language is not None else (query,)
|
||||
func = fn.plainto_tsquery if plain else fn.to_tsquery
|
||||
return Expression(self, TS_MATCH, func(*params))
|
||||
|
||||
|
||||
def Match(field, query, language=None):
|
||||
params = (language, query) if language is not None else (query,)
|
||||
field_params = (language, field) if language is not None else (field,)
|
||||
return Expression(
|
||||
fn.to_tsvector(*field_params),
|
||||
TS_MATCH,
|
||||
fn.to_tsquery(*params))
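# Usage sketch (illustrative): full-text search against a stored tsvector
# column, or ad-hoc with Match() on a plain text column. The model and the
# search phrase are placeholder assumptions.
def _full_text_search_example():
    db = PostgresqlExtDatabase('my_app', user='postgres')

    class Article(Model):
        body = TextField()
        search_content = TSVectorField()   # pre-computed, GIN-indexed tsvector

        class Meta:
            database = db

    # tsvector @@ to_tsquery(...) against the stored column.
    indexed = Article.select().where(
        Article.search_content.match('python & peewee'))
    # Build the tsvector on the fly from the raw text column.
    ad_hoc = Article.select().where(Match(Article.body, 'python & peewee'))
    return indexed, ad_hoc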
|
||||
|
||||
|
||||
class IntervalField(Field):
|
||||
field_type = 'INTERVAL'
|
||||
|
||||
|
||||
class FetchManyCursor(object):
|
||||
__slots__ = ('cursor', 'array_size', 'exhausted', 'iterable')
|
||||
|
||||
def __init__(self, cursor, array_size=None):
|
||||
self.cursor = cursor
|
||||
self.array_size = array_size or cursor.itersize
|
||||
self.exhausted = False
|
||||
self.iterable = self.row_gen()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return self.cursor.description
|
||||
|
||||
def close(self):
|
||||
self.cursor.close()
|
||||
|
||||
def row_gen(self):
|
||||
while True:
|
||||
rows = self.cursor.fetchmany(self.array_size)
|
||||
if not rows:
|
||||
return
|
||||
for row in rows:
|
||||
yield row
|
||||
|
||||
def fetchone(self):
|
||||
if self.exhausted:
|
||||
return
|
||||
try:
|
||||
return next(self.iterable)
|
||||
except StopIteration:
|
||||
self.exhausted = True
|
||||
|
||||
|
||||
class ServerSideQuery(Node):
|
||||
def __init__(self, query, array_size=None):
|
||||
self.query = query
|
||||
self.array_size = array_size
|
||||
self._cursor_wrapper = None
|
||||
|
||||
def __sql__(self, ctx):
|
||||
return self.query.__sql__(ctx)
|
||||
|
||||
def __iter__(self):
|
||||
if self._cursor_wrapper is None:
|
||||
self._execute(self.query._database)
|
||||
return iter(self._cursor_wrapper.iterator())
|
||||
|
||||
def _execute(self, database):
|
||||
if self._cursor_wrapper is None:
|
||||
cursor = database.execute(self.query, named_cursor=True,
|
||||
array_size=self.array_size)
|
||||
self._cursor_wrapper = self.query._get_cursor_wrapper(cursor)
|
||||
return self._cursor_wrapper
|
||||
|
||||
|
||||
def ServerSide(query, database=None, array_size=None):
|
||||
if database is None:
|
||||
database = query._database
|
||||
with database.transaction():
|
||||
server_side_query = ServerSideQuery(query, array_size=array_size)
|
||||
for row in server_side_query:
|
||||
yield row
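# Usage sketch (illustrative): iterating a large result set through a named
# (server-side) cursor so rows arrive in batches of `array_size` instead of
# being fetched all at once. `db` and `Measurement` are assumed to be an
# existing PostgresqlExtDatabase and model.
def _server_side_example(db, Measurement):
    query = Measurement.select()
    count = 0
    for row in ServerSide(query, database=db, array_size=1000):
        count += 1    # each batch is pulled lazily inside a transaction
    return count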
|
||||
|
||||
|
||||
class _empty_object(object):
|
||||
__slots__ = ()
|
||||
def __nonzero__(self):
|
||||
return False
|
||||
__bool__ = __nonzero__
|
||||
|
||||
__named_cursor__ = _empty_object()
|
||||
|
||||
|
||||
class PostgresqlExtDatabase(PostgresqlDatabase):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._register_hstore = kwargs.pop('register_hstore', False)
|
||||
self._server_side_cursors = kwargs.pop('server_side_cursors', False)
|
||||
super(PostgresqlExtDatabase, self).__init__(*args, **kwargs)
|
||||
|
||||
def _connect(self):
|
||||
conn = super(PostgresqlExtDatabase, self)._connect()
|
||||
if self._register_hstore:
|
||||
register_hstore(conn, globally=True)
|
||||
return conn
|
||||
|
||||
def cursor(self, commit=None):
|
||||
if self.is_closed():
|
||||
if self.autoconnect:
|
||||
self.connect()
|
||||
else:
|
||||
raise InterfaceError('Error, database connection not opened.')
|
||||
if commit is __named_cursor__:
|
||||
return self._state.conn.cursor(name=str(uuid.uuid1()))
|
||||
return self._state.conn.cursor()
|
||||
|
||||
def execute(self, query, commit=SENTINEL, named_cursor=False,
|
||||
array_size=None, **context_options):
|
||||
ctx = self.get_sql_context(**context_options)
|
||||
sql, params = ctx.sql(query).query()
|
||||
named_cursor = named_cursor or (self._server_side_cursors and
|
||||
sql[:6].lower() == 'select')
|
||||
if named_cursor:
|
||||
commit = __named_cursor__
|
||||
cursor = self.execute_sql(sql, params, commit=commit)
|
||||
if named_cursor:
|
||||
cursor = FetchManyCursor(cursor, array_size)
|
||||
return cursor
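# Usage sketch (illustrative): the two constructor flags consumed above.
# register_hstore installs the psycopg2 hstore adapter on every new
# connection, and server_side_cursors routes SELECT statements through a
# named cursor automatically. Connection details are placeholder values.
def _ext_database_example():
    return PostgresqlExtDatabase(
        'my_app',
        user='postgres',
        register_hstore=True,        # needed for HStoreField round-trips
        server_side_cursors=True)    # SELECTs are wrapped in FetchManyCursor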
|
|
@@ -1,799 +0,0 @@
|
|||
try:
|
||||
from collections import OrderedDict
|
||||
except ImportError:
|
||||
OrderedDict = dict
|
||||
from collections import namedtuple
|
||||
from inspect import isclass
|
||||
import re
|
||||
|
||||
from peewee import *
|
||||
from peewee import _StringField
|
||||
from peewee import _query_val_transform
|
||||
from peewee import CommaNodeList
|
||||
from peewee import SCOPE_VALUES
|
||||
from peewee import make_snake_case
|
||||
from peewee import text_type
|
||||
try:
|
||||
from pymysql.constants import FIELD_TYPE
|
||||
except ImportError:
|
||||
try:
|
||||
from MySQLdb.constants import FIELD_TYPE
|
||||
except ImportError:
|
||||
FIELD_TYPE = None
|
||||
try:
|
||||
from playhouse import postgres_ext
|
||||
except ImportError:
|
||||
postgres_ext = None
|
||||
|
||||
RESERVED_WORDS = set([
|
||||
'and', 'as', 'assert', 'break', 'class', 'continue', 'def', 'del', 'elif',
|
||||
'else', 'except', 'exec', 'finally', 'for', 'from', 'global', 'if',
|
||||
'import', 'in', 'is', 'lambda', 'not', 'or', 'pass', 'print', 'raise',
|
||||
'return', 'try', 'while', 'with', 'yield',
|
||||
])
|
||||
|
||||
|
||||
class UnknownField(object):
|
||||
pass
|
||||
|
||||
|
||||
class Column(object):
|
||||
"""
|
||||
Store metadata about a database column.
|
||||
"""
|
||||
primary_key_types = (IntegerField, AutoField)
|
||||
|
||||
def __init__(self, name, field_class, raw_column_type, nullable,
|
||||
primary_key=False, column_name=None, index=False,
|
||||
unique=False, default=None, extra_parameters=None):
|
||||
self.name = name
|
||||
self.field_class = field_class
|
||||
self.raw_column_type = raw_column_type
|
||||
self.nullable = nullable
|
||||
self.primary_key = primary_key
|
||||
self.column_name = column_name
|
||||
self.index = index
|
||||
self.unique = unique
|
||||
self.default = default
|
||||
self.extra_parameters = extra_parameters
|
||||
|
||||
# Foreign key metadata.
|
||||
self.rel_model = None
|
||||
self.related_name = None
|
||||
self.to_field = None
|
||||
|
||||
def __repr__(self):
|
||||
attrs = [
|
||||
'field_class',
|
||||
'raw_column_type',
|
||||
'nullable',
|
||||
'primary_key',
|
||||
'column_name']
|
||||
keyword_args = ', '.join(
|
||||
'%s=%s' % (attr, getattr(self, attr))
|
||||
for attr in attrs)
|
||||
return 'Column(%s, %s)' % (self.name, keyword_args)
|
||||
|
||||
def get_field_parameters(self):
|
||||
params = {}
|
||||
if self.extra_parameters is not None:
|
||||
params.update(self.extra_parameters)
|
||||
|
||||
# Set up default attributes.
|
||||
if self.nullable:
|
||||
params['null'] = True
|
||||
if self.field_class is ForeignKeyField or self.name != self.column_name:
|
||||
params['column_name'] = "'%s'" % self.column_name
|
||||
if self.primary_key and not issubclass(self.field_class, AutoField):
|
||||
params['primary_key'] = True
|
||||
if self.default is not None:
|
||||
params['constraints'] = '[SQL("DEFAULT %s")]' % self.default
|
||||
|
||||
# Handle ForeignKeyField-specific attributes.
|
||||
if self.is_foreign_key():
|
||||
params['model'] = self.rel_model
|
||||
if self.to_field:
|
||||
params['field'] = "'%s'" % self.to_field
|
||||
if self.related_name:
|
||||
params['backref'] = "'%s'" % self.related_name
|
||||
|
||||
# Handle indexes on column.
|
||||
if not self.is_primary_key():
|
||||
if self.unique:
|
||||
params['unique'] = 'True'
|
||||
elif self.index and not self.is_foreign_key():
|
||||
params['index'] = 'True'
|
||||
|
||||
return params
|
||||
|
||||
def is_primary_key(self):
|
||||
return self.field_class is AutoField or self.primary_key
|
||||
|
||||
def is_foreign_key(self):
|
||||
return self.field_class is ForeignKeyField
|
||||
|
||||
def is_self_referential_fk(self):
|
||||
return (self.field_class is ForeignKeyField and
|
||||
self.rel_model == "'self'")
|
||||
|
||||
def set_foreign_key(self, foreign_key, model_names, dest=None,
|
||||
related_name=None):
|
||||
self.foreign_key = foreign_key
|
||||
self.field_class = ForeignKeyField
|
||||
if foreign_key.dest_table == foreign_key.table:
|
||||
self.rel_model = "'self'"
|
||||
else:
|
||||
self.rel_model = model_names[foreign_key.dest_table]
|
||||
self.to_field = dest and dest.name or None
|
||||
self.related_name = related_name or None
|
||||
|
||||
def get_field(self):
|
||||
# Generate the field definition for this column.
|
||||
field_params = {}
|
||||
for key, value in self.get_field_parameters().items():
|
||||
if isclass(value) and issubclass(value, Field):
|
||||
value = value.__name__
|
||||
field_params[key] = value
|
||||
|
||||
param_str = ', '.join('%s=%s' % (k, v)
|
||||
for k, v in sorted(field_params.items()))
|
||||
field = '%s = %s(%s)' % (
|
||||
self.name,
|
||||
self.field_class.__name__,
|
||||
param_str)
|
||||
|
||||
if self.field_class is UnknownField:
|
||||
field = '%s # %s' % (field, self.raw_column_type)
|
||||
|
||||
return field
|
||||
|
||||
|
||||
class Metadata(object):
|
||||
column_map = {}
|
||||
extension_import = ''
|
||||
|
||||
def __init__(self, database):
|
||||
self.database = database
|
||||
self.requires_extension = False
|
||||
|
||||
def execute(self, sql, *params):
|
||||
return self.database.execute_sql(sql, params)
|
||||
|
||||
def get_columns(self, table, schema=None):
|
||||
metadata = OrderedDict(
|
||||
(metadata.name, metadata)
|
||||
for metadata in self.database.get_columns(table, schema))
|
||||
|
||||
# Look up the actual column type for each column.
|
||||
column_types, extra_params = self.get_column_types(table, schema)
|
||||
|
||||
# Look up the primary keys.
|
||||
pk_names = self.get_primary_keys(table, schema)
|
||||
if len(pk_names) == 1:
|
||||
pk = pk_names[0]
|
||||
if column_types[pk] is IntegerField:
|
||||
column_types[pk] = AutoField
|
||||
elif column_types[pk] is BigIntegerField:
|
||||
column_types[pk] = BigAutoField
|
||||
|
||||
columns = OrderedDict()
|
||||
for name, column_data in metadata.items():
|
||||
field_class = column_types[name]
|
||||
default = self._clean_default(field_class, column_data.default)
|
||||
|
||||
columns[name] = Column(
|
||||
name,
|
||||
field_class=field_class,
|
||||
raw_column_type=column_data.data_type,
|
||||
nullable=column_data.null,
|
||||
primary_key=column_data.primary_key,
|
||||
column_name=name,
|
||||
default=default,
|
||||
extra_parameters=extra_params.get(name))
|
||||
|
||||
return columns
|
||||
|
||||
def get_column_types(self, table, schema=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def _clean_default(self, field_class, default):
|
||||
if default is None or field_class in (AutoField, BigAutoField) or \
|
||||
default.lower() == 'null':
|
||||
return
|
||||
if issubclass(field_class, _StringField) and \
|
||||
isinstance(default, text_type) and not default.startswith("'"):
|
||||
default = "'%s'" % default
|
||||
return default or "''"
|
||||
|
||||
def get_foreign_keys(self, table, schema=None):
|
||||
return self.database.get_foreign_keys(table, schema)
|
||||
|
||||
def get_primary_keys(self, table, schema=None):
|
||||
return self.database.get_primary_keys(table, schema)
|
||||
|
||||
def get_indexes(self, table, schema=None):
|
||||
return self.database.get_indexes(table, schema)
|
||||
|
||||
|
||||
class PostgresqlMetadata(Metadata):
|
||||
column_map = {
|
||||
16: BooleanField,
|
||||
17: BlobField,
|
||||
20: BigIntegerField,
|
||||
21: IntegerField,
|
||||
23: IntegerField,
|
||||
25: TextField,
|
||||
700: FloatField,
|
||||
701: DoubleField,
|
||||
1042: CharField, # blank-padded CHAR
|
||||
1043: CharField,
|
||||
1082: DateField,
|
||||
1114: DateTimeField,
|
||||
1184: DateTimeField,
|
||||
1083: TimeField,
|
||||
1266: TimeField,
|
||||
1700: DecimalField,
|
||||
2950: TextField, # UUID
|
||||
}
|
||||
array_types = {
|
||||
1000: BooleanField,
|
||||
1001: BlobField,
|
||||
1005: SmallIntegerField,
|
||||
1007: IntegerField,
|
||||
1009: TextField,
|
||||
1014: CharField,
|
||||
1015: CharField,
|
||||
1016: BigIntegerField,
|
||||
1115: DateTimeField,
|
||||
1182: DateField,
|
||||
1183: TimeField,
|
||||
}
|
||||
extension_import = 'from playhouse.postgres_ext import *'
|
||||
|
||||
def __init__(self, database):
|
||||
super(PostgresqlMetadata, self).__init__(database)
|
||||
|
||||
if postgres_ext is not None:
|
||||
# Attempt to add types like HStore and JSON.
|
||||
cursor = self.execute('select oid, typname, format_type(oid, NULL)'
|
||||
' from pg_type;')
|
||||
results = cursor.fetchall()
|
||||
|
||||
for oid, typname, formatted_type in results:
|
||||
if typname == 'json':
|
||||
self.column_map[oid] = postgres_ext.JSONField
|
||||
elif typname == 'jsonb':
|
||||
self.column_map[oid] = postgres_ext.BinaryJSONField
|
||||
elif typname == 'hstore':
|
||||
self.column_map[oid] = postgres_ext.HStoreField
|
||||
elif typname == 'tsvector':
|
||||
self.column_map[oid] = postgres_ext.TSVectorField
|
||||
|
||||
for oid in self.array_types:
|
||||
self.column_map[oid] = postgres_ext.ArrayField
|
||||
|
||||
def get_column_types(self, table, schema):
|
||||
column_types = {}
|
||||
extra_params = {}
|
||||
extension_types = set((
|
||||
postgres_ext.ArrayField,
|
||||
postgres_ext.BinaryJSONField,
|
||||
postgres_ext.JSONField,
|
||||
postgres_ext.TSVectorField,
|
||||
postgres_ext.HStoreField)) if postgres_ext is not None else set()
|
||||
|
||||
# Look up the actual column type for each column.
|
||||
identifier = '"%s"."%s"' % (schema, table)
|
||||
cursor = self.execute('SELECT * FROM %s LIMIT 1' % identifier)
|
||||
|
||||
# Store column metadata in dictionary keyed by column name.
|
||||
for column_description in cursor.description:
|
||||
name = column_description.name
|
||||
oid = column_description.type_code
|
||||
column_types[name] = self.column_map.get(oid, UnknownField)
|
||||
if column_types[name] in extension_types:
|
||||
self.requires_extension = True
|
||||
if oid in self.array_types:
|
||||
extra_params[name] = {'field_class': self.array_types[oid]}
|
||||
|
||||
return column_types, extra_params
|
||||
|
||||
def get_columns(self, table, schema=None):
|
||||
schema = schema or 'public'
|
||||
return super(PostgresqlMetadata, self).get_columns(table, schema)
|
||||
|
||||
def get_foreign_keys(self, table, schema=None):
|
||||
schema = schema or 'public'
|
||||
return super(PostgresqlMetadata, self).get_foreign_keys(table, schema)
|
||||
|
||||
def get_primary_keys(self, table, schema=None):
|
||||
schema = schema or 'public'
|
||||
return super(PostgresqlMetadata, self).get_primary_keys(table, schema)
|
||||
|
||||
def get_indexes(self, table, schema=None):
|
||||
schema = schema or 'public'
|
||||
return super(PostgresqlMetadata, self).get_indexes(table, schema)
|
||||
|
||||
|
||||
class MySQLMetadata(Metadata):
|
||||
if FIELD_TYPE is None:
|
||||
column_map = {}
|
||||
else:
|
||||
column_map = {
|
||||
FIELD_TYPE.BLOB: TextField,
|
||||
FIELD_TYPE.CHAR: CharField,
|
||||
FIELD_TYPE.DATE: DateField,
|
||||
FIELD_TYPE.DATETIME: DateTimeField,
|
||||
FIELD_TYPE.DECIMAL: DecimalField,
|
||||
FIELD_TYPE.DOUBLE: FloatField,
|
||||
FIELD_TYPE.FLOAT: FloatField,
|
||||
FIELD_TYPE.INT24: IntegerField,
|
||||
FIELD_TYPE.LONG_BLOB: TextField,
|
||||
FIELD_TYPE.LONG: IntegerField,
|
||||
FIELD_TYPE.LONGLONG: BigIntegerField,
|
||||
FIELD_TYPE.MEDIUM_BLOB: TextField,
|
||||
FIELD_TYPE.NEWDECIMAL: DecimalField,
|
||||
FIELD_TYPE.SHORT: IntegerField,
|
||||
FIELD_TYPE.STRING: CharField,
|
||||
FIELD_TYPE.TIMESTAMP: DateTimeField,
|
||||
FIELD_TYPE.TIME: TimeField,
|
||||
FIELD_TYPE.TINY_BLOB: TextField,
|
||||
FIELD_TYPE.TINY: IntegerField,
|
||||
FIELD_TYPE.VAR_STRING: CharField,
|
||||
}
|
||||
|
||||
def __init__(self, database, **kwargs):
|
||||
if 'password' in kwargs:
|
||||
kwargs['passwd'] = kwargs.pop('password')
|
||||
super(MySQLMetadata, self).__init__(database, **kwargs)
|
||||
|
||||
def get_column_types(self, table, schema=None):
|
||||
column_types = {}
|
||||
|
||||
# Look up the actual column type for each column.
|
||||
cursor = self.execute('SELECT * FROM `%s` LIMIT 1' % table)
|
||||
|
||||
# Store column metadata in dictionary keyed by column name.
|
||||
for column_description in cursor.description:
|
||||
name, type_code = column_description[:2]
|
||||
column_types[name] = self.column_map.get(type_code, UnknownField)
|
||||
|
||||
return column_types, {}
|
||||
|
||||
|
||||
class SqliteMetadata(Metadata):
|
||||
column_map = {
|
||||
'bigint': BigIntegerField,
|
||||
'blob': BlobField,
|
||||
'bool': BooleanField,
|
||||
'boolean': BooleanField,
|
||||
'char': CharField,
|
||||
'date': DateField,
|
||||
'datetime': DateTimeField,
|
||||
'decimal': DecimalField,
|
||||
'float': FloatField,
|
||||
'integer': IntegerField,
|
||||
'integer unsigned': IntegerField,
|
||||
'int': IntegerField,
|
||||
'long': BigIntegerField,
|
||||
'numeric': DecimalField,
|
||||
'real': FloatField,
|
||||
'smallinteger': IntegerField,
|
||||
'smallint': IntegerField,
|
||||
'smallint unsigned': IntegerField,
|
||||
'text': TextField,
|
||||
'time': TimeField,
|
||||
'varchar': CharField,
|
||||
}
|
||||
|
||||
begin = r'(?:["\[\(]+)?'
|
||||
end = r'(?:["\]\)]+)?'
|
||||
re_foreign_key = (
|
||||
r'(?:FOREIGN KEY\s*)?'
|
||||
r'{begin}(.+?){end}\s+(?:.+\s+)?'
|
||||
r'references\s+{begin}(.+?){end}'
|
||||
r'\s*\(["|\[]?(.+?)["|\]]?\)').format(begin=begin, end=end)
|
||||
re_varchar = r'^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$'
|
||||
|
||||
def _map_col(self, column_type):
|
||||
raw_column_type = column_type.lower()
|
||||
if raw_column_type in self.column_map:
|
||||
field_class = self.column_map[raw_column_type]
|
||||
elif re.search(self.re_varchar, raw_column_type):
|
||||
field_class = CharField
|
||||
else:
|
||||
column_type = re.sub(r'\(.+\)', '', raw_column_type)
|
||||
if column_type == '':
|
||||
field_class = BareField
|
||||
else:
|
||||
field_class = self.column_map.get(column_type, UnknownField)
|
||||
return field_class
|
||||
|
||||
def get_column_types(self, table, schema=None):
|
||||
column_types = {}
|
||||
columns = self.database.get_columns(table)
|
||||
|
||||
for column in columns:
|
||||
column_types[column.name] = self._map_col(column.data_type)
|
||||
|
||||
return column_types, {}
|
||||
|
||||
|
||||
_DatabaseMetadata = namedtuple('_DatabaseMetadata', (
|
||||
'columns',
|
||||
'primary_keys',
|
||||
'foreign_keys',
|
||||
'model_names',
|
||||
'indexes'))
|
||||
|
||||
|
||||
class DatabaseMetadata(_DatabaseMetadata):
|
||||
def multi_column_indexes(self, table):
|
||||
accum = []
|
||||
for index in self.indexes[table]:
|
||||
if len(index.columns) > 1:
|
||||
field_names = [self.columns[table][column].name
|
||||
for column in index.columns
|
||||
if column in self.columns[table]]
|
||||
accum.append((field_names, index.unique))
|
||||
return accum
|
||||
|
||||
def column_indexes(self, table):
|
||||
accum = {}
|
||||
for index in self.indexes[table]:
|
||||
if len(index.columns) == 1:
|
||||
accum[index.columns[0]] = index.unique
|
||||
return accum
|
||||
|
||||
|
||||
class Introspector(object):
|
||||
pk_classes = [AutoField, IntegerField]
|
||||
|
||||
def __init__(self, metadata, schema=None):
|
||||
self.metadata = metadata
|
||||
self.schema = schema
|
||||
|
||||
def __repr__(self):
|
||||
return '<Introspector: %s>' % self.metadata.database
|
||||
|
||||
@classmethod
|
||||
def from_database(cls, database, schema=None):
|
||||
if isinstance(database, PostgresqlDatabase):
|
||||
metadata = PostgresqlMetadata(database)
|
||||
elif isinstance(database, MySQLDatabase):
|
||||
metadata = MySQLMetadata(database)
|
||||
elif isinstance(database, SqliteDatabase):
|
||||
metadata = SqliteMetadata(database)
|
||||
else:
|
||||
raise ValueError('Introspection not supported for %r' % database)
|
||||
return cls(metadata, schema=schema)
|
||||
|
||||
def get_database_class(self):
|
||||
return type(self.metadata.database)
|
||||
|
||||
def get_database_name(self):
|
||||
return self.metadata.database.database
|
||||
|
||||
def get_database_kwargs(self):
|
||||
return self.metadata.database.connect_params
|
||||
|
||||
def get_additional_imports(self):
|
||||
if self.metadata.requires_extension:
|
||||
return '\n' + self.metadata.extension_import
|
||||
return ''
|
||||
|
||||
def make_model_name(self, table, snake_case=True):
|
||||
if snake_case:
|
||||
table = make_snake_case(table)
|
||||
model = re.sub(r'[^\w]+', '', table)
|
||||
model_name = ''.join(sub.title() for sub in model.split('_'))
|
||||
if not model_name[0].isalpha():
|
||||
model_name = 'T' + model_name
|
||||
return model_name
|
||||
|
||||
def make_column_name(self, column, is_foreign_key=False, snake_case=True):
|
||||
column = column.strip()
|
||||
if snake_case:
|
||||
column = make_snake_case(column)
|
||||
column = column.lower()
|
||||
if is_foreign_key:
|
||||
# Strip "_id" from foreign keys, unless the foreign-key happens to
|
||||
# be named "_id", in which case the name is retained.
|
||||
column = re.sub('_id$', '', column) or column
|
||||
|
||||
# Remove characters that are invalid for Python identifiers.
|
||||
column = re.sub(r'[^\w]+', '_', column)
|
||||
if column in RESERVED_WORDS:
|
||||
column += '_'
|
||||
if len(column) and column[0].isdigit():
|
||||
column = '_' + column
|
||||
return column
|
||||
|
||||
def introspect(self, table_names=None, literal_column_names=False,
|
||||
include_views=False, snake_case=True):
|
||||
# Retrieve all the tables in the database.
|
||||
tables = self.metadata.database.get_tables(schema=self.schema)
|
||||
if include_views:
|
||||
views = self.metadata.database.get_views(schema=self.schema)
|
||||
tables.extend([view.name for view in views])
|
||||
|
||||
if table_names is not None:
|
||||
tables = [table for table in tables if table in table_names]
|
||||
table_set = set(tables)
|
||||
|
||||
# Store a mapping of table name -> dictionary of columns.
|
||||
columns = {}
|
||||
|
||||
# Store a mapping of table name -> set of primary key columns.
|
||||
primary_keys = {}
|
||||
|
||||
# Store a mapping of table -> foreign keys.
|
||||
foreign_keys = {}
|
||||
|
||||
# Store a mapping of table name -> model name.
|
||||
model_names = {}
|
||||
|
||||
# Store a mapping of table name -> indexes.
|
||||
indexes = {}
|
||||
|
||||
# Gather the columns for each table.
|
||||
for table in tables:
|
||||
table_indexes = self.metadata.get_indexes(table, self.schema)
|
||||
table_columns = self.metadata.get_columns(table, self.schema)
|
||||
try:
|
||||
foreign_keys[table] = self.metadata.get_foreign_keys(
|
||||
table, self.schema)
|
||||
except ValueError as exc:
|
||||
err(*exc.args)
|
||||
foreign_keys[table] = []
|
||||
else:
|
||||
# If there is a possibility we could exclude a dependent table,
|
||||
# ensure that we introspect it so FKs will work.
|
||||
if table_names is not None:
|
||||
for foreign_key in foreign_keys[table]:
|
||||
if foreign_key.dest_table not in table_set:
|
||||
tables.append(foreign_key.dest_table)
|
||||
table_set.add(foreign_key.dest_table)
|
||||
|
||||
model_names[table] = self.make_model_name(table, snake_case)
|
||||
|
||||
# Collect sets of all the column names as well as all the
|
||||
# foreign-key column names.
|
||||
lower_col_names = set(column_name.lower()
|
||||
for column_name in table_columns)
|
||||
fks = set(fk_col.column for fk_col in foreign_keys[table])
|
||||
|
||||
for col_name, column in table_columns.items():
|
||||
if literal_column_names:
|
||||
new_name = re.sub(r'[^\w]+', '_', col_name)
|
||||
else:
|
||||
new_name = self.make_column_name(col_name, col_name in fks,
|
||||
snake_case)
|
||||
|
||||
# If we have two columns, "parent" and "parent_id", ensure
|
||||
# that we don't introduce naming conflicts.
|
||||
lower_name = col_name.lower()
|
||||
if lower_name.endswith('_id') and new_name in lower_col_names:
|
||||
new_name = col_name.lower()
|
||||
|
||||
column.name = new_name
|
||||
|
||||
for index in table_indexes:
|
||||
if len(index.columns) == 1:
|
||||
column = index.columns[0]
|
||||
if column in table_columns:
|
||||
table_columns[column].unique = index.unique
|
||||
table_columns[column].index = True
|
||||
|
||||
primary_keys[table] = self.metadata.get_primary_keys(
|
||||
table, self.schema)
|
||||
columns[table] = table_columns
|
||||
indexes[table] = table_indexes
|
||||
|
||||
# Gather all instances where we might have a `related_name` conflict,
|
||||
# either due to multiple FKs on a table pointing to the same table,
|
||||
# or a related_name that would conflict with an existing field.
|
||||
related_names = {}
|
||||
sort_fn = lambda foreign_key: foreign_key.column
|
||||
for table in tables:
|
||||
models_referenced = set()
|
||||
for foreign_key in sorted(foreign_keys[table], key=sort_fn):
|
||||
try:
|
||||
column = columns[table][foreign_key.column]
|
||||
except KeyError:
|
||||
continue
|
||||
|
||||
dest_table = foreign_key.dest_table
|
||||
if dest_table in models_referenced:
|
||||
related_names[column] = '%s_%s_set' % (
|
||||
dest_table,
|
||||
column.name)
|
||||
else:
|
||||
models_referenced.add(dest_table)
|
||||
|
||||
# On the second pass convert all foreign keys.
|
||||
for table in tables:
|
||||
for foreign_key in foreign_keys[table]:
|
||||
src = columns[foreign_key.table][foreign_key.column]
|
||||
try:
|
||||
dest = columns[foreign_key.dest_table][
|
||||
foreign_key.dest_column]
|
||||
except KeyError:
|
||||
dest = None
|
||||
|
||||
src.set_foreign_key(
|
||||
foreign_key=foreign_key,
|
||||
model_names=model_names,
|
||||
dest=dest,
|
||||
related_name=related_names.get(src))
|
||||
|
||||
return DatabaseMetadata(
|
||||
columns,
|
||||
primary_keys,
|
||||
foreign_keys,
|
||||
model_names,
|
||||
indexes)
|
||||
|
||||
def generate_models(self, skip_invalid=False, table_names=None,
|
||||
literal_column_names=False, bare_fields=False,
|
||||
include_views=False):
|
||||
database = self.introspect(table_names, literal_column_names,
|
||||
include_views)
|
||||
models = {}
|
||||
|
||||
class BaseModel(Model):
|
||||
class Meta:
|
||||
database = self.metadata.database
|
||||
schema = self.schema
|
||||
|
||||
def _create_model(table, models):
|
||||
for foreign_key in database.foreign_keys[table]:
|
||||
dest = foreign_key.dest_table
|
||||
|
||||
if dest not in models and dest != table:
|
||||
_create_model(dest, models)
|
||||
|
||||
primary_keys = []
|
||||
columns = database.columns[table]
|
||||
for column_name, column in columns.items():
|
||||
if column.primary_key:
|
||||
primary_keys.append(column.name)
|
||||
|
||||
multi_column_indexes = database.multi_column_indexes(table)
|
||||
column_indexes = database.column_indexes(table)
|
||||
|
||||
class Meta:
|
||||
indexes = multi_column_indexes
|
||||
table_name = table
|
||||
|
||||
# Fix models with multi-column primary keys.
|
||||
composite_key = False
|
||||
if len(primary_keys) == 0:
|
||||
primary_keys = columns.keys()
|
||||
if len(primary_keys) > 1:
|
||||
Meta.primary_key = CompositeKey(*[
|
||||
field.name for col, field in columns.items()
|
||||
if col in primary_keys])
|
||||
composite_key = True
|
||||
|
||||
attrs = {'Meta': Meta}
|
||||
for column_name, column in columns.items():
|
||||
FieldClass = column.field_class
|
||||
if FieldClass is not ForeignKeyField and bare_fields:
|
||||
FieldClass = BareField
|
||||
elif FieldClass is UnknownField:
|
||||
FieldClass = BareField
|
||||
|
||||
params = {
|
||||
'column_name': column_name,
|
||||
'null': column.nullable}
|
||||
if column.primary_key and composite_key:
|
||||
if FieldClass is AutoField:
|
||||
FieldClass = IntegerField
|
||||
params['primary_key'] = False
|
||||
elif column.primary_key and FieldClass is not AutoField:
|
||||
params['primary_key'] = True
|
||||
if column.is_foreign_key():
|
||||
if column.is_self_referential_fk():
|
||||
params['model'] = 'self'
|
||||
else:
|
||||
dest_table = column.foreign_key.dest_table
|
||||
params['model'] = models[dest_table]
|
||||
if column.to_field:
|
||||
params['field'] = column.to_field
|
||||
|
||||
# Generate a unique related name.
|
||||
params['backref'] = '%s_%s_rel' % (table, column_name)
|
||||
|
||||
if column.default is not None:
|
||||
constraint = SQL('DEFAULT %s' % column.default)
|
||||
params['constraints'] = [constraint]
|
||||
|
||||
if column_name in column_indexes and not \
|
||||
column.is_primary_key():
|
||||
if column_indexes[column_name]:
|
||||
params['unique'] = True
|
||||
elif not column.is_foreign_key():
|
||||
params['index'] = True
|
||||
|
||||
attrs[column.name] = FieldClass(**params)
|
||||
|
||||
try:
|
||||
models[table] = type(str(table), (BaseModel,), attrs)
|
||||
except ValueError:
|
||||
if not skip_invalid:
|
||||
raise
|
||||
|
||||
# Actually generate Model classes.
|
||||
for table, model in sorted(database.model_names.items()):
|
||||
if table not in models:
|
||||
_create_model(table, models)
|
||||
|
||||
return models
|
||||
|
||||
|
||||
def introspect(database, schema=None):
|
||||
introspector = Introspector.from_database(database, schema=schema)
|
||||
return introspector.introspect()
|
||||
|
||||
|
||||
def generate_models(database, schema=None, **options):
|
||||
introspector = Introspector.from_database(database, schema=schema)
|
||||
return introspector.generate_models(**options)
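# Usage sketch (illustrative): introspecting an existing database and using
# one of the generated model classes directly. The SQLite path and the
# 'user' table name are placeholder assumptions.
def _generate_models_example():
    db = SqliteDatabase('/tmp/legacy.db')
    models = generate_models(db)       # maps table name -> generated Model
    if 'user' in models:
        User = models['user']
        return list(User.select().limit(10))
    return []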
|
||||
|
||||
|
||||
def print_model(model, indexes=True, inline_indexes=False):
|
||||
print(model._meta.name)
|
||||
for field in model._meta.sorted_fields:
|
||||
parts = [' %s %s' % (field.name, field.field_type)]
|
||||
if field.primary_key:
|
||||
parts.append(' PK')
|
||||
elif inline_indexes:
|
||||
if field.unique:
|
||||
parts.append(' UNIQUE')
|
||||
elif field.index:
|
||||
parts.append(' INDEX')
|
||||
if isinstance(field, ForeignKeyField):
|
||||
parts.append(' FK: %s.%s' % (field.rel_model.__name__,
|
||||
field.rel_field.name))
|
||||
print(''.join(parts))
|
||||
|
||||
if indexes:
|
||||
index_list = model._meta.fields_to_index()
|
||||
if not index_list:
|
||||
return
|
||||
|
||||
print('\nindex(es)')
|
||||
for index in index_list:
|
||||
parts = [' ']
|
||||
ctx = model._meta.database.get_sql_context()
|
||||
with ctx.scope_values(param='%s', quote='""'):
|
||||
ctx.sql(CommaNodeList(index._expressions))
|
||||
if index._where:
|
||||
ctx.literal(' WHERE ')
|
||||
ctx.sql(index._where)
|
||||
sql, params = ctx.query()
|
||||
|
||||
clean = sql % tuple(map(_query_val_transform, params))
|
||||
parts.append(clean.replace('"', ''))
|
||||
|
||||
if index._unique:
|
||||
parts.append(' UNIQUE')
|
||||
print(''.join(parts))
|
||||
|
||||
|
||||
def get_table_sql(model):
|
||||
sql, params = model._schema._create_table().query()
|
||||
if model._meta.database.param != '%s':
|
||||
sql = sql.replace(model._meta.database.param, '%s')
|
||||
|
||||
# Format and indent the table declaration, simplest possible approach.
|
||||
match_obj = re.match(r'^(.+?\()(.+)(\).*)', sql)
|
||||
create, columns, extra = match_obj.groups()
|
||||
indented = ',\n'.join(' %s' % column for column in columns.split(', '))
|
||||
|
||||
clean = '\n'.join((create, indented, extra)).strip()
|
||||
return clean % tuple(map(_query_val_transform, params))
|
||||
|
||||
def print_table_sql(model):
|
||||
print(get_table_sql(model))
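# Usage sketch (illustrative): dumping a readable field summary and the
# CREATE TABLE statement for each introspected model. `models` is assumed to
# be the dict returned by generate_models() above.
def _print_schema_example(models):
    for name in sorted(models):
        print_model(models[name], indexes=True)
        print_table_sql(models[name])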
|
|
@@ -1,228 +0,0 @@
|
|||
from peewee import *
|
||||
from peewee import Alias
|
||||
from peewee import SENTINEL
|
||||
from peewee import callable_
|
||||
|
||||
|
||||
_clone_set = lambda s: set(s) if s else set()
|
||||
|
||||
|
||||
def model_to_dict(model, recurse=True, backrefs=False, only=None,
|
||||
exclude=None, seen=None, extra_attrs=None,
|
||||
fields_from_query=None, max_depth=None, manytomany=False):
|
||||
"""
|
||||
Convert a model instance (and any related objects) to a dictionary.
|
||||
|
||||
:param bool recurse: Whether foreign-keys should be recursed.
|
||||
:param bool backrefs: Whether lists of related objects should be recursed.
|
||||
:param only: A list (or set) of field instances indicating which fields
|
||||
should be included.
|
||||
:param exclude: A list (or set) of field instances that should be
|
||||
excluded from the dictionary.
|
||||
:param list extra_attrs: Names of model instance attributes or methods
|
||||
that should be included.
|
||||
:param SelectQuery fields_from_query: Query that was source of model. Take
|
||||
fields explicitly selected by the query and serialize them.
|
||||
:param int max_depth: Maximum depth to recurse, value <= 0 means no max.
|
||||
:param bool manytomany: Process many-to-many fields.
|
||||
"""
|
||||
max_depth = -1 if max_depth is None else max_depth
|
||||
if max_depth == 0:
|
||||
recurse = False
|
||||
|
||||
only = _clone_set(only)
|
||||
extra_attrs = _clone_set(extra_attrs)
|
||||
should_skip = lambda n: (n in exclude) or (only and (n not in only))
|
||||
|
||||
if fields_from_query is not None:
|
||||
for item in fields_from_query._returning:
|
||||
if isinstance(item, Field):
|
||||
only.add(item)
|
||||
elif isinstance(item, Alias):
|
||||
extra_attrs.add(item._alias)
|
||||
|
||||
data = {}
|
||||
exclude = _clone_set(exclude)
|
||||
seen = _clone_set(seen)
|
||||
exclude |= seen
|
||||
model_class = type(model)
|
||||
|
||||
if manytomany:
|
||||
for name, m2m in model._meta.manytomany.items():
|
||||
if should_skip(name):
|
||||
continue
|
||||
|
||||
exclude.update((m2m, m2m.rel_model._meta.manytomany[m2m.backref]))
|
||||
for fkf in m2m.through_model._meta.refs:
|
||||
exclude.add(fkf)
|
||||
|
||||
accum = []
|
||||
for rel_obj in getattr(model, name):
|
||||
accum.append(model_to_dict(
|
||||
rel_obj,
|
||||
recurse=recurse,
|
||||
backrefs=backrefs,
|
||||
only=only,
|
||||
exclude=exclude,
|
||||
max_depth=max_depth - 1))
|
||||
data[name] = accum
|
||||
|
||||
for field in model._meta.sorted_fields:
|
||||
if should_skip(field):
|
||||
continue
|
||||
|
||||
field_data = model.__data__.get(field.name)
|
||||
if isinstance(field, ForeignKeyField) and recurse:
|
||||
if field_data is not None:
|
||||
seen.add(field)
|
||||
rel_obj = getattr(model, field.name)
|
||||
field_data = model_to_dict(
|
||||
rel_obj,
|
||||
recurse=recurse,
|
||||
backrefs=backrefs,
|
||||
only=only,
|
||||
exclude=exclude,
|
||||
seen=seen,
|
||||
max_depth=max_depth - 1)
|
||||
else:
|
||||
field_data = None
|
||||
|
||||
data[field.name] = field_data
|
||||
|
||||
if extra_attrs:
|
||||
for attr_name in extra_attrs:
|
||||
attr = getattr(model, attr_name)
|
||||
if callable_(attr):
|
||||
data[attr_name] = attr()
|
||||
else:
|
||||
data[attr_name] = attr
|
||||
|
||||
if backrefs and recurse:
|
||||
for foreign_key, rel_model in model._meta.backrefs.items():
|
||||
if foreign_key.backref == '+': continue
|
||||
descriptor = getattr(model_class, foreign_key.backref)
|
||||
if descriptor in exclude or foreign_key in exclude:
|
||||
continue
|
||||
if only and (descriptor not in only) and (foreign_key not in only):
|
||||
continue
|
||||
|
||||
accum = []
|
||||
exclude.add(foreign_key)
|
||||
related_query = getattr(model, foreign_key.backref)
|
||||
|
||||
for rel_obj in related_query:
|
||||
accum.append(model_to_dict(
|
||||
rel_obj,
|
||||
recurse=recurse,
|
||||
backrefs=backrefs,
|
||||
only=only,
|
||||
exclude=exclude,
|
||||
max_depth=max_depth - 1))
|
||||
|
||||
data[foreign_key.backref] = accum
|
||||
|
||||
return data
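# Usage sketch (illustrative): serializing a row while following its foreign
# keys but capping recursion depth. `Tweet` and `tweet` (and the id/content
# fields) are placeholder assumptions.
def _model_to_dict_example(Tweet, tweet):
    # Follow tweet -> user, but skip back-references and stop after one level.
    full = model_to_dict(tweet, recurse=True, backrefs=False, max_depth=1)
    # Restrict the output to explicitly listed fields.
    slim = model_to_dict(tweet, only=[Tweet.id, Tweet.content])
    return full, slim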
|
||||
|
||||
|
||||
def update_model_from_dict(instance, data, ignore_unknown=False):
|
||||
meta = instance._meta
|
||||
backrefs = dict([(fk.backref, fk) for fk in meta.backrefs])
|
||||
|
||||
for key, value in data.items():
|
||||
if key in meta.combined:
|
||||
field = meta.combined[key]
|
||||
is_backref = False
|
||||
elif key in backrefs:
|
||||
field = backrefs[key]
|
||||
is_backref = True
|
||||
elif ignore_unknown:
|
||||
setattr(instance, key, value)
|
||||
continue
|
||||
else:
|
||||
raise AttributeError('Unrecognized attribute "%s" for model '
|
||||
'class %s.' % (key, type(instance)))
|
||||
|
||||
is_foreign_key = isinstance(field, ForeignKeyField)
|
||||
|
||||
if not is_backref and is_foreign_key and isinstance(value, dict):
|
||||
try:
|
||||
rel_instance = instance.__rel__[field.name]
|
||||
except KeyError:
|
||||
rel_instance = field.rel_model()
|
||||
setattr(
|
||||
instance,
|
||||
field.name,
|
||||
update_model_from_dict(rel_instance, value, ignore_unknown))
|
||||
elif is_backref and isinstance(value, (list, tuple)):
|
||||
instances = [
|
||||
dict_to_model(field.model, row_data, ignore_unknown)
|
||||
for row_data in value]
|
||||
for rel_instance in instances:
|
||||
setattr(rel_instance, field.name, instance)
|
||||
setattr(instance, field.backref, instances)
|
||||
else:
|
||||
setattr(instance, field.name, value)
|
||||
|
||||
return instance
|
||||
|
||||
|
||||
def dict_to_model(model_class, data, ignore_unknown=False):
|
||||
return update_model_from_dict(model_class(), data, ignore_unknown)
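# Usage sketch (illustrative): rebuilding model instances from plain dicts,
# e.g. deserialized request payloads. The `User` model and its field names
# are placeholder assumptions.
def _dict_to_model_example(User):
    user = dict_to_model(User, {'username': 'zaizee', 'is_admin': False})
    user.save()

    # Apply a partial update to an instance that is already loaded.
    update_model_from_dict(user, {'is_admin': True}, ignore_unknown=True)
    user.save()
    return user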
|
||||
|
||||
|
||||
class ReconnectMixin(object):
|
||||
"""
|
||||
Mixin class that attempts to automatically reconnect to the database under
|
||||
certain error conditions.
|
||||
|
||||
For example, MySQL servers will typically close connections that are idle
|
||||
for 28800 seconds ("wait_timeout" setting). If your application makes use
|
||||
of long-lived connections, you may find your connections are closed after
|
||||
a period of no activity. This mixin will attempt to reconnect automatically
|
||||
when these errors occur.
|
||||
|
||||
This mixin class probably should not be used with Postgres (unless you
|
||||
REALLY know what you are doing) and definitely has no business being used
|
||||
with Sqlite. If you wish to use with Postgres, you will need to adapt the
|
||||
`reconnect_errors` attribute to something appropriate for Postgres.
|
||||
"""
|
||||
reconnect_errors = (
|
||||
# Error class, error message fragment (or empty string for all).
|
||||
(OperationalError, '2006'), # MySQL server has gone away.
|
||||
(OperationalError, '2013'), # Lost connection to MySQL server.
|
||||
(OperationalError, '2014'), # Commands out of sync.
|
||||
|
||||
# mysql-connector raises a slightly different error when an idle
|
||||
# connection is terminated by the server. This is equivalent to 2013.
|
||||
(OperationalError, 'MySQL Connection not available.'),
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(ReconnectMixin, self).__init__(*args, **kwargs)
|
||||
|
||||
# Normalize the reconnect errors to a more efficient data-structure.
|
||||
self._reconnect_errors = {}
|
||||
for exc_class, err_fragment in self.reconnect_errors:
|
||||
self._reconnect_errors.setdefault(exc_class, [])
|
||||
self._reconnect_errors[exc_class].append(err_fragment.lower())
|
||||
|
||||
def execute_sql(self, sql, params=None, commit=SENTINEL):
|
||||
try:
|
||||
return super(ReconnectMixin, self).execute_sql(sql, params, commit)
|
||||
except Exception as exc:
|
||||
exc_class = type(exc)
|
||||
if exc_class not in self._reconnect_errors:
|
||||
raise exc
|
||||
|
||||
exc_repr = str(exc).lower()
|
||||
for err_fragment in self._reconnect_errors[exc_class]:
|
||||
if err_fragment in exc_repr:
|
||||
break
|
||||
else:
|
||||
raise exc
|
||||
|
||||
if not self.is_closed():
|
||||
self.close()
|
||||
self.connect()
|
||||
|
||||
return super(ReconnectMixin, self).execute_sql(sql, params, commit)
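# Usage sketch (illustrative): the mixin is meant to be composed with a
# concrete database class via multiple inheritance, as the docstring above
# describes. The connection parameters are placeholder assumptions.
def _reconnect_example():
    class ReconnectMySQLDatabase(ReconnectMixin, MySQLDatabase):
        """MySQL database that retries once after a dropped connection."""

    db = ReconnectMySQLDatabase('my_app', user='root', password='secret')
    # If the server closed the idle connection (error 2006/2013), this call
    # reconnects and re-issues the statement instead of raising.
    return db.execute_sql('SELECT 1')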
|
|
@@ -1,79 +0,0 @@
|
|||
"""
|
||||
Provide django-style hooks for model events.
|
||||
"""
|
||||
from peewee import Model as _Model
|
||||
|
||||
|
||||
class Signal(object):
|
||||
def __init__(self):
|
||||
self._flush()
|
||||
|
||||
def _flush(self):
|
||||
self._receivers = set()
|
||||
self._receiver_list = []
|
||||
|
||||
def connect(self, receiver, name=None, sender=None):
|
||||
name = name or receiver.__name__
|
||||
key = (name, sender)
|
||||
if key not in self._receivers:
|
||||
self._receivers.add(key)
|
||||
self._receiver_list.append((name, receiver, sender))
|
||||
else:
|
||||
raise ValueError('receiver named %s (for sender=%s) already '
|
||||
'connected' % (name, sender or 'any'))
|
||||
|
||||
def disconnect(self, receiver=None, name=None, sender=None):
|
||||
if receiver:
|
||||
name = name or receiver.__name__
|
||||
if not name:
|
||||
raise ValueError('a receiver or a name must be provided')
|
||||
|
||||
key = (name, sender)
|
||||
if key not in self._receivers:
|
||||
raise ValueError('receiver named %s for sender=%s not found.' %
|
||||
(name, sender or 'any'))
|
||||
|
||||
self._receivers.remove(key)
|
||||
self._receiver_list = [(n, r, s) for n, r, s in self._receiver_list
|
||||
if n != name and s != sender]
|
||||
|
||||
def __call__(self, name=None, sender=None):
|
||||
def decorator(fn):
|
||||
self.connect(fn, name, sender)
|
||||
return fn
|
||||
return decorator
|
||||
|
||||
def send(self, instance, *args, **kwargs):
|
||||
sender = type(instance)
|
||||
responses = []
|
||||
for n, r, s in self._receiver_list:
|
||||
if s is None or isinstance(instance, s):
|
||||
responses.append((r, r(sender, instance, *args, **kwargs)))
|
||||
return responses
|
||||
|
||||
|
||||
pre_save = Signal()
|
||||
post_save = Signal()
|
||||
pre_delete = Signal()
|
||||
post_delete = Signal()
|
||||
pre_init = Signal()
|
||||
|
||||
|
||||
class Model(_Model):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Model, self).__init__(*args, **kwargs)
|
||||
pre_init.send(self)
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
pk_value = self._pk if self._meta.primary_key else True
|
||||
created = kwargs.get('force_insert', False) or not bool(pk_value)
|
||||
pre_save.send(self, created=created)
|
||||
ret = super(Model, self).save(*args, **kwargs)
|
||||
post_save.send(self, created=created)
|
||||
return ret
|
||||
|
||||
def delete_instance(self, *args, **kwargs):
|
||||
pre_delete.send(self)
|
||||
ret = super(Model, self).delete_instance(*args, **kwargs)
|
||||
post_delete.send(self)
|
||||
return ret
|
|
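The signals module deleted above dispatches hooks through Signal.send(), which calls every receiver whose sender matches the instance's class. A minimal sketch of attaching a receiver, assuming an in-memory SQLite database and an illustrative Episode model (neither exists in this repository):

import peewee
from playhouse.signals import Model, post_save

db = peewee.SqliteDatabase(':memory:')          # throwaway database for the sketch

class Episode(Model):
    title = peewee.CharField()

    class Meta:
        database = db

@post_save(sender=Episode)
def on_episode_saved(sender, instance, created):
    # Receivers are called as receiver(sender, instance, *args, **kwargs);
    # save() above passes created=True for inserts.
    print('saved %s (new row: %s)' % (instance.title, created))

db.create_tables([Episode])
Episode(title='Pilot').save()                   # fires pre_save, then post_save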
@ -1,103 +0,0 @@
|
|||
"""
|
||||
Peewee integration with pysqlcipher.
|
||||
|
||||
Project page: https://github.com/leapcode/pysqlcipher/
|
||||
|
||||
**WARNING!!! EXPERIMENTAL!!!**
|
||||
|
||||
* Although this extension's code is short, it has not been properly
|
||||
peer-reviewed yet and may have introduced vulnerabilities.
|
||||
|
||||
Also note that this code relies on pysqlcipher and sqlcipher, and
|
||||
the code there might have vulnerabilities as well, but since these
|
||||
are widely used crypto modules, we can expect "short zero days" there.
|
||||
|
||||
Example usage:
|
||||
|
||||
from peewee.playground.ciphersql_ext import SqlCipherDatabase
|
||||
db = SqlCipherDatabase('/path/to/my.db', passphrase="don'tuseme4real")
|
||||
|
||||
* `passphrase`: should be "long enough".
|
||||
Note that *length beats vocabulary* (the search space grows exponentially with length), and even
|
||||
a lowercase-only passphrase like easytorememberyethardforotherstoguess
|
||||
packs more noise than 8 random printable characters and *can* be memorized.
|
||||
|
||||
When opening an existing database, passphrase should be the one used when the
|
||||
database was created. If the passphrase is incorrect, an exception will only be
|
||||
raised **when you access the database**.
|
||||
|
||||
If you need to ask for an interactive passphrase, here's example code you can
|
||||
put after the `db = ...` line:
|
||||
|
||||
try: # Just access the database so that it checks the encryption.
|
||||
db.get_tables()
|
||||
# We're looking for a DatabaseError with a specific error message.
|
||||
except peewee.DatabaseError as e:
|
||||
# Check whether the message *means* "passphrase is wrong"
|
||||
if e.args[0] == 'file is encrypted or is not a database':
|
||||
raise Exception('Developer should prompt user for passphrase '
|
||||
'again.')
|
||||
else:
|
||||
# A different DatabaseError. Raise it.
|
||||
raise e
|
||||
|
||||
See a more elaborate example with this code at
|
||||
https://gist.github.com/thedod/11048875
|
||||
"""
|
||||
import datetime
|
||||
import decimal
|
||||
import sys
|
||||
|
||||
from peewee import *
|
||||
from playhouse.sqlite_ext import SqliteExtDatabase
|
||||
if sys.version_info[0] != 3:
|
||||
from pysqlcipher import dbapi2 as sqlcipher
|
||||
else:
|
||||
try:
|
||||
from sqlcipher3 import dbapi2 as sqlcipher
|
||||
except ImportError:
|
||||
from pysqlcipher3 import dbapi2 as sqlcipher
|
||||
|
||||
sqlcipher.register_adapter(decimal.Decimal, str)
|
||||
sqlcipher.register_adapter(datetime.date, str)
|
||||
sqlcipher.register_adapter(datetime.time, str)
|
||||
|
||||
|
||||
class _SqlCipherDatabase(object):
|
||||
def _connect(self):
|
||||
params = dict(self.connect_params)
|
||||
passphrase = params.pop('passphrase', '').replace("'", "''")
|
||||
|
||||
conn = sqlcipher.connect(self.database, isolation_level=None, **params)
|
||||
try:
|
||||
if passphrase:
|
||||
conn.execute("PRAGMA key='%s'" % passphrase)
|
||||
self._add_conn_hooks(conn)
|
||||
except:
|
||||
conn.close()
|
||||
raise
|
||||
return conn
|
||||
|
||||
def set_passphrase(self, passphrase):
|
||||
if not self.is_closed():
|
||||
raise ImproperlyConfigured('Cannot set passphrase when database '
|
||||
'is open. To change passphrase of an '
|
||||
'open database use the rekey() method.')
|
||||
|
||||
self.connect_params['passphrase'] = passphrase
|
||||
|
||||
def rekey(self, passphrase):
|
||||
if self.is_closed():
|
||||
self.connect()
|
||||
|
||||
self.execute_sql("PRAGMA rekey='%s'" % passphrase.replace("'", "''"))
|
||||
self.connect_params['passphrase'] = passphrase
|
||||
return True
|
||||
|
||||
|
||||
class SqlCipherDatabase(_SqlCipherDatabase, SqliteDatabase):
|
||||
pass
|
||||
|
||||
|
||||
class SqlCipherExtDatabase(_SqlCipherDatabase, SqliteExtDatabase):
|
||||
pass
|
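A minimal sketch of the SqlCipherDatabase API above, assuming sqlcipher3 or pysqlcipher3 is installed; the path and both passphrases are placeholders:

from playhouse.sqlcipher_ext import SqlCipherDatabase

db = SqlCipherDatabase('/tmp/secrets.db',
                       passphrase='correct horse battery staple')
db.connect()
db.execute_sql('CREATE TABLE IF NOT EXISTS note (body TEXT)')
db.rekey('an even longer replacement passphrase')   # re-encrypts the open database
db.close()

As the docstring notes, a wrong passphrase is only detected when the database is first accessed, so the CREATE TABLE above is also the earliest point a bad key would raise.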
File diff suppressed because it is too large
|
@ -1,522 +0,0 @@
|
|||
import datetime
|
||||
import hashlib
|
||||
import heapq
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
import zlib
|
||||
try:
|
||||
from collections import Counter
|
||||
except ImportError:
|
||||
Counter = None
|
||||
try:
|
||||
from urlparse import urlparse
|
||||
except ImportError:
|
||||
from urllib.parse import urlparse
|
||||
|
||||
try:
|
||||
from playhouse._sqlite_ext import TableFunction
|
||||
except ImportError:
|
||||
TableFunction = None
|
||||
|
||||
|
||||
SQLITE_DATETIME_FORMATS = (
|
||||
'%Y-%m-%d %H:%M:%S',
|
||||
'%Y-%m-%d %H:%M:%S.%f',
|
||||
'%Y-%m-%d',
|
||||
'%H:%M:%S',
|
||||
'%H:%M:%S.%f',
|
||||
'%H:%M')
|
||||
|
||||
from peewee import format_date_time
|
||||
|
||||
def format_date_time_sqlite(date_value):
|
||||
return format_date_time(date_value, SQLITE_DATETIME_FORMATS)
|
||||
|
||||
try:
|
||||
from playhouse import _sqlite_udf as cython_udf
|
||||
except ImportError:
|
||||
cython_udf = None
|
||||
|
||||
|
||||
# Group udf by function.
|
||||
CONTROL_FLOW = 'control_flow'
|
||||
DATE = 'date'
|
||||
FILE = 'file'
|
||||
HELPER = 'helpers'
|
||||
MATH = 'math'
|
||||
STRING = 'string'
|
||||
|
||||
AGGREGATE_COLLECTION = {}
|
||||
TABLE_FUNCTION_COLLECTION = {}
|
||||
UDF_COLLECTION = {}
|
||||
|
||||
|
||||
class synchronized_dict(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(synchronized_dict, self).__init__(*args, **kwargs)
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def __getitem__(self, key):
|
||||
with self._lock:
|
||||
return super(synchronized_dict, self).__getitem__(key)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
with self._lock:
|
||||
return super(synchronized_dict, self).__setitem__(key, value)
|
||||
|
||||
def __delitem__(self, key):
|
||||
with self._lock:
|
||||
return super(synchronized_dict, self).__delitem__(key)
|
||||
|
||||
|
||||
STATE = synchronized_dict()
|
||||
SETTINGS = synchronized_dict()
|
||||
|
||||
# Class and function decorators.
|
||||
def aggregate(*groups):
|
||||
def decorator(klass):
|
||||
for group in groups:
|
||||
AGGREGATE_COLLECTION.setdefault(group, [])
|
||||
AGGREGATE_COLLECTION[group].append(klass)
|
||||
return klass
|
||||
return decorator
|
||||
|
||||
def table_function(*groups):
|
||||
def decorator(klass):
|
||||
for group in groups:
|
||||
TABLE_FUNCTION_COLLECTION.setdefault(group, [])
|
||||
TABLE_FUNCTION_COLLECTION[group].append(klass)
|
||||
return klass
|
||||
return decorator
|
||||
|
||||
def udf(*groups):
|
||||
def decorator(fn):
|
||||
for group in groups:
|
||||
UDF_COLLECTION.setdefault(group, [])
|
||||
UDF_COLLECTION[group].append(fn)
|
||||
return fn
|
||||
return decorator
|
||||
|
||||
# Register aggregates / functions with connection.
|
||||
def register_aggregate_groups(db, *groups):
|
||||
seen = set()
|
||||
for group in groups:
|
||||
klasses = AGGREGATE_COLLECTION.get(group, ())
|
||||
for klass in klasses:
|
||||
name = getattr(klass, 'name', klass.__name__)
|
||||
if name not in seen:
|
||||
seen.add(name)
|
||||
db.register_aggregate(klass, name)
|
||||
|
||||
def register_table_function_groups(db, *groups):
|
||||
seen = set()
|
||||
for group in groups:
|
||||
klasses = TABLE_FUNCTION_COLLECTION.get(group, ())
|
||||
for klass in klasses:
|
||||
if klass.name not in seen:
|
||||
seen.add(klass.name)
|
||||
db.register_table_function(klass)
|
||||
|
||||
def register_udf_groups(db, *groups):
|
||||
seen = set()
|
||||
for group in groups:
|
||||
functions = UDF_COLLECTION.get(group, ())
|
||||
for function in functions:
|
||||
name = function.__name__
|
||||
if name not in seen:
|
||||
seen.add(name)
|
||||
db.register_function(function, name)
|
||||
|
||||
def register_groups(db, *groups):
|
||||
register_aggregate_groups(db, *groups)
|
||||
register_table_function_groups(db, *groups)
|
||||
register_udf_groups(db, *groups)
|
||||
|
||||
def register_all(db):
|
||||
register_aggregate_groups(db, *AGGREGATE_COLLECTION)
|
||||
register_table_function_groups(db, *TABLE_FUNCTION_COLLECTION)
|
||||
register_udf_groups(db, *UDF_COLLECTION)
|
||||
|
||||
|
||||
# Begin actual user-defined functions and aggregates.
|
||||
|
||||
# Scalar functions.
|
||||
@udf(CONTROL_FLOW)
|
||||
def if_then_else(cond, truthy, falsey=None):
|
||||
if cond:
|
||||
return truthy
|
||||
return falsey
|
||||
|
||||
@udf(DATE)
|
||||
def strip_tz(date_str):
|
||||
date_str = date_str.replace('T', ' ')
|
||||
tz_idx1 = date_str.find('+')
|
||||
if tz_idx1 != -1:
|
||||
return date_str[:tz_idx1]
|
||||
tz_idx2 = date_str.find('-')
|
||||
if tz_idx2 > 13:
|
||||
return date_str[:tz_idx2]
|
||||
return date_str
|
||||
|
||||
@udf(DATE)
|
||||
def human_delta(nseconds, glue=', '):
|
||||
parts = (
|
||||
(86400 * 365, 'year'),
|
||||
(86400 * 30, 'month'),
|
||||
(86400 * 7, 'week'),
|
||||
(86400, 'day'),
|
||||
(3600, 'hour'),
|
||||
(60, 'minute'),
|
||||
(1, 'second'),
|
||||
)
|
||||
accum = []
|
||||
for offset, name in parts:
|
||||
val, nseconds = divmod(nseconds, offset)
|
||||
if val:
|
||||
suffix = val != 1 and 's' or ''
|
||||
accum.append('%s %s%s' % (val, name, suffix))
|
||||
if not accum:
|
||||
return '0 seconds'
|
||||
return glue.join(accum)
|
||||
|
||||
@udf(FILE)
|
||||
def file_ext(filename):
|
||||
try:
|
||||
res = os.path.splitext(filename)
|
||||
except ValueError:
|
||||
return None
|
||||
return res[1]
|
||||
|
||||
@udf(FILE)
|
||||
def file_read(filename):
|
||||
try:
|
||||
with open(filename) as fh:
|
||||
return fh.read()
|
||||
except:
|
||||
pass
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
@udf(HELPER)
|
||||
def gzip(data, compression=9):
|
||||
return buffer(zlib.compress(data, compression))
|
||||
|
||||
@udf(HELPER)
|
||||
def gunzip(data):
|
||||
return zlib.decompress(data)
|
||||
else:
|
||||
@udf(HELPER)
|
||||
def gzip(data, compression=9):
|
||||
if isinstance(data, str):
|
||||
data = bytes(data.encode('raw_unicode_escape'))
|
||||
return zlib.compress(data, compression)
|
||||
|
||||
@udf(HELPER)
|
||||
def gunzip(data):
|
||||
return zlib.decompress(data)
|
||||
|
||||
@udf(HELPER)
|
||||
def hostname(url):
|
||||
parse_result = urlparse(url)
|
||||
if parse_result:
|
||||
return parse_result.netloc
|
||||
|
||||
@udf(HELPER)
|
||||
def toggle(key):
|
||||
key = key.lower()
|
||||
STATE[key] = ret = not STATE.get(key)
|
||||
return ret
|
||||
|
||||
@udf(HELPER)
|
||||
def setting(key, value=None):
|
||||
if value is None:
|
||||
return SETTINGS.get(key)
|
||||
else:
|
||||
SETTINGS[key] = value
|
||||
return value
|
||||
|
||||
@udf(HELPER)
|
||||
def clear_settings():
|
||||
SETTINGS.clear()
|
||||
|
||||
@udf(HELPER)
|
||||
def clear_toggles():
|
||||
STATE.clear()
|
||||
|
||||
@udf(MATH)
|
||||
def randomrange(start, end=None, step=None):
|
||||
if end is None:
|
||||
start, end = 0, start
|
||||
elif step is None:
|
||||
step = 1
|
||||
return random.randrange(start, end, step)
|
||||
|
||||
@udf(MATH)
|
||||
def gauss_distribution(mean, sigma):
|
||||
try:
|
||||
return random.gauss(mean, sigma)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@udf(MATH)
|
||||
def sqrt(n):
|
||||
try:
|
||||
return math.sqrt(n)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@udf(MATH)
|
||||
def tonumber(s):
|
||||
try:
|
||||
return int(s)
|
||||
except ValueError:
|
||||
try:
|
||||
return float(s)
|
||||
except:
|
||||
return None
|
||||
|
||||
@udf(STRING)
|
||||
def substr_count(haystack, needle):
|
||||
if not haystack or not needle:
|
||||
return 0
|
||||
return haystack.count(needle)
|
||||
|
||||
@udf(STRING)
|
||||
def strip_chars(haystack, chars):
|
||||
return haystack.strip(chars)
|
||||
|
||||
def _hash(constructor, *args):
|
||||
hash_obj = constructor()
|
||||
for arg in args:
|
||||
hash_obj.update(arg)
|
||||
return hash_obj.hexdigest()
|
||||
|
||||
# Aggregates.
|
||||
class _heap_agg(object):
|
||||
def __init__(self):
|
||||
self.heap = []
|
||||
self.ct = 0
|
||||
|
||||
def process(self, value):
|
||||
return value
|
||||
|
||||
def step(self, value):
|
||||
self.ct += 1
|
||||
heapq.heappush(self.heap, self.process(value))
|
||||
|
||||
class _datetime_heap_agg(_heap_agg):
|
||||
def process(self, value):
|
||||
return format_date_time_sqlite(value)
|
||||
|
||||
if sys.version_info[:2] == (2, 6):
|
||||
def total_seconds(td):
|
||||
return (td.seconds +
|
||||
(td.days * 86400) +
|
||||
(td.microseconds / (10.**6)))
|
||||
else:
|
||||
total_seconds = lambda td: td.total_seconds()
|
||||
|
||||
@aggregate(DATE)
|
||||
class mintdiff(_datetime_heap_agg):
|
||||
def finalize(self):
|
||||
dtp = min_diff = None
|
||||
while self.heap:
|
||||
if min_diff is None:
|
||||
if dtp is None:
|
||||
dtp = heapq.heappop(self.heap)
|
||||
continue
|
||||
dt = heapq.heappop(self.heap)
|
||||
diff = dt - dtp
|
||||
if min_diff is None or min_diff > diff:
|
||||
min_diff = diff
|
||||
dtp = dt
|
||||
if min_diff is not None:
|
||||
return total_seconds(min_diff)
|
||||
|
||||
@aggregate(DATE)
|
||||
class avgtdiff(_datetime_heap_agg):
|
||||
def finalize(self):
|
||||
if self.ct < 1:
|
||||
return
|
||||
elif self.ct == 1:
|
||||
return 0
|
||||
|
||||
total = ct = 0
|
||||
dtp = None
|
||||
while self.heap:
|
||||
if total == 0:
|
||||
if dtp is None:
|
||||
dtp = heapq.heappop(self.heap)
|
||||
continue
|
||||
|
||||
dt = heapq.heappop(self.heap)
|
||||
diff = dt - dtp
|
||||
ct += 1
|
||||
total += total_seconds(diff)
|
||||
dtp = dt
|
||||
|
||||
return float(total) / ct
|
||||
|
||||
@aggregate(DATE)
|
||||
class duration(object):
|
||||
def __init__(self):
|
||||
self._min = self._max = None
|
||||
|
||||
def step(self, value):
|
||||
dt = format_date_time_sqlite(value)
|
||||
if self._min is None or dt < self._min:
|
||||
self._min = dt
|
||||
if self._max is None or dt > self._max:
|
||||
self._max = dt
|
||||
|
||||
def finalize(self):
|
||||
if self._min and self._max:
|
||||
td = (self._max - self._min)
|
||||
return total_seconds(td)
|
||||
return None
|
||||
|
||||
@aggregate(MATH)
|
||||
class mode(object):
|
||||
if Counter:
|
||||
def __init__(self):
|
||||
self.items = Counter()
|
||||
|
||||
def step(self, *args):
|
||||
self.items.update(args)
|
||||
|
||||
def finalize(self):
|
||||
if self.items:
|
||||
return self.items.most_common(1)[0][0]
|
||||
else:
|
||||
def __init__(self):
|
||||
self.items = []
|
||||
|
||||
def step(self, item):
|
||||
self.items.append(item)
|
||||
|
||||
def finalize(self):
|
||||
if self.items:
|
||||
return max(set(self.items), key=self.items.count)
|
||||
|
||||
@aggregate(MATH)
|
||||
class minrange(_heap_agg):
|
||||
def finalize(self):
|
||||
if self.ct == 0:
|
||||
return
|
||||
elif self.ct == 1:
|
||||
return 0
|
||||
|
||||
prev = min_diff = None
|
||||
|
||||
while self.heap:
|
||||
if min_diff is None:
|
||||
if prev is None:
|
||||
prev = heapq.heappop(self.heap)
|
||||
continue
|
||||
curr = heapq.heappop(self.heap)
|
||||
diff = curr - prev
|
||||
if min_diff is None or min_diff > diff:
|
||||
min_diff = diff
|
||||
prev = curr
|
||||
return min_diff
|
||||
|
||||
@aggregate(MATH)
|
||||
class avgrange(_heap_agg):
|
||||
def finalize(self):
|
||||
if self.ct == 0:
|
||||
return
|
||||
elif self.ct == 1:
|
||||
return 0
|
||||
|
||||
total = ct = 0
|
||||
prev = None
|
||||
while self.heap:
|
||||
if total == 0:
|
||||
if prev is None:
|
||||
prev = heapq.heappop(self.heap)
|
||||
continue
|
||||
|
||||
curr = heapq.heappop(self.heap)
|
||||
diff = curr - prev
|
||||
ct += 1
|
||||
total += diff
|
||||
prev = curr
|
||||
|
||||
return float(total) / ct
|
||||
|
||||
@aggregate(MATH)
|
||||
class _range(object):
|
||||
name = 'range'
|
||||
|
||||
def __init__(self):
|
||||
self._min = self._max = None
|
||||
|
||||
def step(self, value):
|
||||
if self._min is None or value < self._min:
|
||||
self._min = value
|
||||
if self._max is None or value > self._max:
|
||||
self._max = value
|
||||
|
||||
def finalize(self):
|
||||
if self._min is not None and self._max is not None:
|
||||
return self._max - self._min
|
||||
return None
|
||||
|
||||
|
||||
if cython_udf is not None:
|
||||
damerau_levenshtein_dist = udf(STRING)(cython_udf.damerau_levenshtein_dist)
|
||||
levenshtein_dist = udf(STRING)(cython_udf.levenshtein_dist)
|
||||
str_dist = udf(STRING)(cython_udf.str_dist)
|
||||
median = aggregate(MATH)(cython_udf.median)
|
||||
|
||||
|
||||
if TableFunction is not None:
|
||||
@table_function(STRING)
|
||||
class RegexSearch(TableFunction):
|
||||
params = ['regex', 'search_string']
|
||||
columns = ['match']
|
||||
name = 'regex_search'
|
||||
|
||||
def initialize(self, regex=None, search_string=None):
|
||||
self._iter = re.finditer(regex, search_string)
|
||||
|
||||
def iterate(self, idx):
|
||||
return (next(self._iter).group(0),)
|
||||
|
||||
@table_function(DATE)
|
||||
class DateSeries(TableFunction):
|
||||
params = ['start', 'stop', 'step_seconds']
|
||||
columns = ['date']
|
||||
name = 'date_series'
|
||||
|
||||
def initialize(self, start, stop, step_seconds=86400):
|
||||
self.start = format_date_time_sqlite(start)
|
||||
self.stop = format_date_time_sqlite(stop)
|
||||
step_seconds = int(step_seconds)
|
||||
self.step_seconds = datetime.timedelta(seconds=step_seconds)
|
||||
|
||||
if (self.start.hour == 0 and
|
||||
self.start.minute == 0 and
|
||||
self.start.second == 0 and
|
||||
step_seconds >= 86400):
|
||||
self.format = '%Y-%m-%d'
|
||||
elif (self.start.year == 1900 and
|
||||
self.start.month == 1 and
|
||||
self.start.day == 1 and
|
||||
self.stop.year == 1900 and
|
||||
self.stop.month == 1 and
|
||||
self.stop.day == 1 and
|
||||
step_seconds < 86400):
|
||||
self.format = '%H:%M:%S'
|
||||
else:
|
||||
self.format = '%Y-%m-%d %H:%M:%S'
|
||||
|
||||
def iterate(self, idx):
|
||||
if self.start > self.stop:
|
||||
raise StopIteration
|
||||
current = self.start
|
||||
self.start += self.step_seconds
|
||||
return (current.strftime(self.format),)
|
|
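A minimal sketch of the group-registration helpers above, assuming the playhouse sqlite_udf and sqlite_ext modules packaged here are importable; only names defined in the code above are used:

from playhouse.sqlite_ext import SqliteExtDatabase
from playhouse import sqlite_udf

db = SqliteExtDatabase(':memory:')
sqlite_udf.register_groups(db, sqlite_udf.STRING, sqlite_udf.MATH)

# The registered scalar functions are now callable from SQL.
cursor = db.execute_sql("SELECT substr_count('banana', 'an'), sqrt(9)")
print(cursor.fetchone())    # -> (2, 3.0)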
@ -1,330 +0,0 @@
|
|||
import logging
|
||||
import weakref
|
||||
from threading import local as thread_local
|
||||
from threading import Event
|
||||
from threading import Thread
|
||||
try:
|
||||
from Queue import Queue
|
||||
except ImportError:
|
||||
from queue import Queue
|
||||
|
||||
try:
|
||||
import gevent
|
||||
from gevent import Greenlet as GThread
|
||||
from gevent.event import Event as GEvent
|
||||
from gevent.local import local as greenlet_local
|
||||
from gevent.queue import Queue as GQueue
|
||||
except ImportError:
|
||||
GThread = GQueue = GEvent = None
|
||||
|
||||
from peewee import SENTINEL
|
||||
from playhouse.sqlite_ext import SqliteExtDatabase
|
||||
|
||||
|
||||
logger = logging.getLogger('peewee.sqliteq')
|
||||
|
||||
|
||||
class ResultTimeout(Exception):
|
||||
pass
|
||||
|
||||
class WriterPaused(Exception):
|
||||
pass
|
||||
|
||||
class ShutdownException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class AsyncCursor(object):
|
||||
__slots__ = ('sql', 'params', 'commit', 'timeout',
|
||||
'_event', '_cursor', '_exc', '_idx', '_rows', '_ready')
|
||||
|
||||
def __init__(self, event, sql, params, commit, timeout):
|
||||
self._event = event
|
||||
self.sql = sql
|
||||
self.params = params
|
||||
self.commit = commit
|
||||
self.timeout = timeout
|
||||
self._cursor = self._exc = self._idx = self._rows = None
|
||||
self._ready = False
|
||||
|
||||
def set_result(self, cursor, exc=None):
|
||||
self._cursor = cursor
|
||||
self._exc = exc
|
||||
self._idx = 0
|
||||
self._rows = cursor.fetchall() if exc is None else []
|
||||
self._event.set()
|
||||
return self
|
||||
|
||||
def _wait(self, timeout=None):
|
||||
timeout = timeout if timeout is not None else self.timeout
|
||||
if not self._event.wait(timeout=timeout) and timeout:
|
||||
raise ResultTimeout('results not ready, timed out.')
|
||||
if self._exc is not None:
|
||||
raise self._exc
|
||||
self._ready = True
|
||||
|
||||
def __iter__(self):
|
||||
if not self._ready:
|
||||
self._wait()
|
||||
if self._exc is not None:
|
||||
raise self._exc
|
||||
return self
|
||||
|
||||
def next(self):
|
||||
if not self._ready:
|
||||
self._wait()
|
||||
try:
|
||||
obj = self._rows[self._idx]
|
||||
except IndexError:
|
||||
raise StopIteration
|
||||
else:
|
||||
self._idx += 1
|
||||
return obj
|
||||
__next__ = next
|
||||
|
||||
@property
|
||||
def lastrowid(self):
|
||||
if not self._ready:
|
||||
self._wait()
|
||||
return self._cursor.lastrowid
|
||||
|
||||
@property
|
||||
def rowcount(self):
|
||||
if not self._ready:
|
||||
self._wait()
|
||||
return self._cursor.rowcount
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return self._cursor.description
|
||||
|
||||
def close(self):
|
||||
self._cursor.close()
|
||||
|
||||
def fetchall(self):
|
||||
return list(self) # Iterating implies waiting until populated.
|
||||
|
||||
def fetchone(self):
|
||||
if not self._ready:
|
||||
self._wait()
|
||||
try:
|
||||
return next(self)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
SHUTDOWN = StopIteration
|
||||
PAUSE = object()
|
||||
UNPAUSE = object()
|
||||
|
||||
|
||||
class Writer(object):
|
||||
__slots__ = ('database', 'queue')
|
||||
|
||||
def __init__(self, database, queue):
|
||||
self.database = database
|
||||
self.queue = queue
|
||||
|
||||
def run(self):
|
||||
conn = self.database.connection()
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
if conn is None: # Paused.
|
||||
if self.wait_unpause():
|
||||
conn = self.database.connection()
|
||||
else:
|
||||
conn = self.loop(conn)
|
||||
except ShutdownException:
|
||||
logger.info('writer received shutdown request, exiting.')
|
||||
return
|
||||
finally:
|
||||
if conn is not None:
|
||||
self.database._close(conn)
|
||||
self.database._state.reset()
|
||||
|
||||
def wait_unpause(self):
|
||||
obj = self.queue.get()
|
||||
if obj is UNPAUSE:
|
||||
logger.info('writer unpaused - reconnecting to database.')
|
||||
return True
|
||||
elif obj is SHUTDOWN:
|
||||
raise ShutdownException()
|
||||
elif obj is PAUSE:
|
||||
logger.error('writer received pause, but is already paused.')
|
||||
else:
|
||||
obj.set_result(None, WriterPaused())
|
||||
logger.warning('writer paused, not handling %s', obj)
|
||||
|
||||
def loop(self, conn):
|
||||
obj = self.queue.get()
|
||||
if isinstance(obj, AsyncCursor):
|
||||
self.execute(obj)
|
||||
elif obj is PAUSE:
|
||||
logger.info('writer paused - closing database connection.')
|
||||
self.database._close(conn)
|
||||
self.database._state.reset()
|
||||
return
|
||||
elif obj is UNPAUSE:
|
||||
logger.error('writer received unpause, but is already running.')
|
||||
elif obj is SHUTDOWN:
|
||||
raise ShutdownException()
|
||||
else:
|
||||
logger.error('writer received unsupported object: %s', obj)
|
||||
return conn
|
||||
|
||||
def execute(self, obj):
|
||||
logger.debug('received query %s', obj.sql)
|
||||
try:
|
||||
cursor = self.database._execute(obj.sql, obj.params, obj.commit)
|
||||
except Exception as execute_err:
|
||||
cursor = None
|
||||
exc = execute_err  # Python 3 clears the except-block variable on exit, so keep a reference.
|
||||
else:
|
||||
exc = None
|
||||
return obj.set_result(cursor, exc)
|
||||
|
||||
|
||||
class SqliteQueueDatabase(SqliteExtDatabase):
|
||||
WAL_MODE_ERROR_MESSAGE = ('SQLite must be configured to use the WAL '
|
||||
'journal mode when using this feature. WAL mode '
|
||||
'allows one or more readers to continue reading '
|
||||
'while another connection writes to the '
|
||||
'database.')
|
||||
|
||||
def __init__(self, database, use_gevent=False, autostart=True,
|
||||
queue_max_size=None, results_timeout=None, *args, **kwargs):
|
||||
kwargs['check_same_thread'] = False
|
||||
|
||||
# Ensure that journal_mode is WAL. This value is passed to the parent
|
||||
# class constructor below.
|
||||
pragmas = self._validate_journal_mode(kwargs.pop('pragmas', None))
|
||||
|
||||
# Reference to execute_sql on the parent class. Since we've overridden
|
||||
# execute_sql(), this is just a handy way to reference the real
|
||||
# implementation.
|
||||
Parent = super(SqliteQueueDatabase, self)
|
||||
self._execute = Parent.execute_sql
|
||||
|
||||
# Call the parent class constructor with our modified pragmas.
|
||||
Parent.__init__(database, pragmas=pragmas, *args, **kwargs)
|
||||
|
||||
self._autostart = autostart
|
||||
self._results_timeout = results_timeout
|
||||
self._is_stopped = True
|
||||
|
||||
# Get different objects depending on the threading implementation.
|
||||
self._thread_helper = self.get_thread_impl(use_gevent)(queue_max_size)
|
||||
|
||||
# Create the writer thread, optionally starting it.
|
||||
self._create_write_queue()
|
||||
if self._autostart:
|
||||
self.start()
|
||||
|
||||
def get_thread_impl(self, use_gevent):
|
||||
return GreenletHelper if use_gevent else ThreadHelper
|
||||
|
||||
def _validate_journal_mode(self, pragmas=None):
|
||||
if pragmas:
|
||||
pdict = dict((k.lower(), v) for (k, v) in pragmas)
|
||||
if pdict.get('journal_mode', 'wal').lower() != 'wal':
|
||||
raise ValueError(self.WAL_MODE_ERROR_MESSAGE)
|
||||
|
||||
return [(k, v) for (k, v) in pragmas
|
||||
if k != 'journal_mode'] + [('journal_mode', 'wal')]
|
||||
else:
|
||||
return [('journal_mode', 'wal')]
|
||||
|
||||
def _create_write_queue(self):
|
||||
self._write_queue = self._thread_helper.queue()
|
||||
|
||||
def queue_size(self):
|
||||
return self._write_queue.qsize()
|
||||
|
||||
def execute_sql(self, sql, params=None, commit=SENTINEL, timeout=None):
|
||||
if commit is SENTINEL:
|
||||
commit = not sql.lower().startswith('select')
|
||||
|
||||
if not commit:
|
||||
return self._execute(sql, params, commit=commit)
|
||||
|
||||
cursor = AsyncCursor(
|
||||
event=self._thread_helper.event(),
|
||||
sql=sql,
|
||||
params=params,
|
||||
commit=commit,
|
||||
timeout=self._results_timeout if timeout is None else timeout)
|
||||
self._write_queue.put(cursor)
|
||||
return cursor
|
||||
|
||||
def start(self):
|
||||
with self._lock:
|
||||
if not self._is_stopped:
|
||||
return False
|
||||
def run():
|
||||
writer = Writer(self, self._write_queue)
|
||||
writer.run()
|
||||
|
||||
self._writer = self._thread_helper.thread(run)
|
||||
self._writer.start()
|
||||
self._is_stopped = False
|
||||
return True
|
||||
|
||||
def stop(self):
|
||||
logger.debug('environment stop requested.')
|
||||
with self._lock:
|
||||
if self._is_stopped:
|
||||
return False
|
||||
self._write_queue.put(SHUTDOWN)
|
||||
self._writer.join()
|
||||
self._is_stopped = True
|
||||
return True
|
||||
|
||||
def is_stopped(self):
|
||||
with self._lock:
|
||||
return self._is_stopped
|
||||
|
||||
def pause(self):
|
||||
with self._lock:
|
||||
self._write_queue.put(PAUSE)
|
||||
|
||||
def unpause(self):
|
||||
with self._lock:
|
||||
self._write_queue.put(UNPAUSE)
|
||||
|
||||
def __unsupported__(self, *args, **kwargs):
|
||||
raise ValueError('This method is not supported by %r.' % type(self))
|
||||
atomic = transaction = savepoint = __unsupported__
|
||||
|
||||
|
||||
class ThreadHelper(object):
|
||||
__slots__ = ('queue_max_size',)
|
||||
|
||||
def __init__(self, queue_max_size=None):
|
||||
self.queue_max_size = queue_max_size
|
||||
|
||||
def event(self): return Event()
|
||||
|
||||
def queue(self, max_size=None):
|
||||
max_size = max_size if max_size is not None else self.queue_max_size
|
||||
return Queue(maxsize=max_size or 0)
|
||||
|
||||
def thread(self, fn, *args, **kwargs):
|
||||
thread = Thread(target=fn, args=args, kwargs=kwargs)
|
||||
thread.daemon = True
|
||||
return thread
|
||||
|
||||
|
||||
class GreenletHelper(ThreadHelper):
|
||||
__slots__ = ()
|
||||
|
||||
def event(self): return GEvent()
|
||||
|
||||
def queue(self, max_size=None):
|
||||
max_size = max_size if max_size is not None else self.queue_max_size
|
||||
return GQueue(maxsize=max_size or 0)
|
||||
|
||||
def thread(self, fn, *args, **kwargs):
|
||||
def wrap(*a, **k):
|
||||
gevent.sleep()
|
||||
return fn(*a, **k)
|
||||
return GThread(wrap, *args, **kwargs)
|
|
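A minimal sketch of SqliteQueueDatabase as defined above; the file path is a placeholder (WAL mode needs a real file, not ':memory:'), and write statements return an AsyncCursor that only blocks when its result is read:

from playhouse.sqliteq import SqliteQueueDatabase

db = SqliteQueueDatabase('/tmp/queue_demo.db',
                         autostart=True, results_timeout=5.0)
db.execute_sql('CREATE TABLE IF NOT EXISTS log (msg TEXT)')
cursor = db.execute_sql('INSERT INTO log (msg) VALUES (?)', ('hello',))
print(cursor.lastrowid)     # blocks until the writer thread has run the INSERT
db.stop()                   # drain the queue and join the writer thread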
@ -1,62 +0,0 @@
|
|||
from functools import wraps
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger('peewee')
|
||||
|
||||
|
||||
class _QueryLogHandler(logging.Handler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.queries = []
|
||||
logging.Handler.__init__(self, *args, **kwargs)
|
||||
|
||||
def emit(self, record):
|
||||
self.queries.append(record)
|
||||
|
||||
|
||||
class count_queries(object):
|
||||
def __init__(self, only_select=False):
|
||||
self.only_select = only_select
|
||||
self.count = 0
|
||||
|
||||
def get_queries(self):
|
||||
return self._handler.queries
|
||||
|
||||
def __enter__(self):
|
||||
self._handler = _QueryLogHandler()
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.addHandler(self._handler)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
logger.removeHandler(self._handler)
|
||||
if self.only_select:
|
||||
self.count = len([q for q in self._handler.queries
|
||||
if q.msg[0].startswith('SELECT ')])
|
||||
else:
|
||||
self.count = len(self._handler.queries)
|
||||
|
||||
|
||||
class assert_query_count(count_queries):
|
||||
def __init__(self, expected, only_select=False):
|
||||
super(assert_query_count, self).__init__(only_select=only_select)
|
||||
self.expected = expected
|
||||
|
||||
def __call__(self, f):
|
||||
@wraps(f)
|
||||
def decorated(*args, **kwds):
|
||||
with self:
|
||||
ret = f(*args, **kwds)
|
||||
|
||||
self._assert_count()
|
||||
return ret
|
||||
|
||||
return decorated
|
||||
|
||||
def _assert_count(self):
|
||||
error_msg = '%s != %s' % (self.count, self.expected)
|
||||
assert self.count == self.expected, error_msg
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
super(assert_query_count, self).__exit__(exc_type, exc_val, exc_tb)
|
||||
self._assert_count()
|
|
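A minimal sketch of the two helpers above, reusing the illustrative Episode model from the signals sketch earlier; both rely on peewee logging each executed query:

from playhouse.test_utils import count_queries, assert_query_count

with count_queries() as counter:
    Episode.select().count()
    list(Episode.select())
print(counter.count)            # 2 queries were logged inside the block

@assert_query_count(1)
def fetch_first():
    return Episode.select().first()

fetch_first()                   # passes: exactly one SELECT was issued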
@ -8,12 +8,25 @@ See the datetime section of the Python Library Reference for information
|
|||
on how to use these modules.
|
||||
'''
|
||||
|
||||
import sys
|
||||
import datetime
|
||||
import os.path
|
||||
|
||||
from pytz.exceptions import AmbiguousTimeError
|
||||
from pytz.exceptions import InvalidTimeError
|
||||
from pytz.exceptions import NonExistentTimeError
|
||||
from pytz.exceptions import UnknownTimeZoneError
|
||||
from pytz.lazy import LazyDict, LazyList, LazySet # noqa
|
||||
from pytz.tzinfo import unpickler, BaseTzInfo
|
||||
from pytz.tzfile import build_tzinfo
|
||||
|
||||
|
||||
# The IANA (nee Olson) database is updated several times a year.
|
||||
OLSON_VERSION = '2017b'
|
||||
VERSION = '2017.2' # Switching to pip compatible version numbering.
|
||||
OLSON_VERSION = '2019c'
|
||||
VERSION = '2019.3' # pip compatible version number.
|
||||
__version__ = VERSION
|
||||
|
||||
OLSEN_VERSION = OLSON_VERSION # Old releases had this misspelling
|
||||
OLSEN_VERSION = OLSON_VERSION # Old releases had this misspelling
|
||||
|
||||
__all__ = [
|
||||
'timezone', 'utc', 'country_timezones', 'country_names',
|
||||
|
@ -21,23 +34,11 @@ __all__ = [
|
|||
'NonExistentTimeError', 'UnknownTimeZoneError',
|
||||
'all_timezones', 'all_timezones_set',
|
||||
'common_timezones', 'common_timezones_set',
|
||||
]
|
||||
|
||||
import sys, datetime, os.path, gettext
|
||||
|
||||
from pytz.exceptions import AmbiguousTimeError
|
||||
from pytz.exceptions import InvalidTimeError
|
||||
from pytz.exceptions import NonExistentTimeError
|
||||
from pytz.exceptions import UnknownTimeZoneError
|
||||
from pytz.lazy import LazyDict, LazyList, LazySet
|
||||
from pytz.tzinfo import unpickler
|
||||
from pytz.tzfile import build_tzinfo, _byte_string
|
||||
'BaseTzInfo',
|
||||
]
|
||||
|
||||
|
||||
try:
|
||||
unicode
|
||||
|
||||
except NameError: # Python 3.x
|
||||
if sys.version_info[0] > 2: # Python 3.x
|
||||
|
||||
# Python 3.x doesn't have unicode(), making writing code
|
||||
# for Python 2.3 and Python 3.x a pain.
|
||||
|
@ -52,10 +53,13 @@ except NameError: # Python 3.x
|
|||
...
|
||||
UnicodeEncodeError: ...
|
||||
"""
|
||||
s.encode('ASCII') # Raise an exception if not ASCII
|
||||
return s # But return the original string - not a byte string.
|
||||
if type(s) == bytes:
|
||||
s = s.decode('ASCII')
|
||||
else:
|
||||
s.encode('ASCII') # Raise an exception if not ASCII
|
||||
return s  # But return the string - not a byte string.
|
||||
|
||||
else: # Python 2.x
|
||||
else: # Python 2.x
|
||||
|
||||
def ascii(s):
|
||||
r"""
|
||||
|
@ -76,24 +80,31 @@ def open_resource(name):
|
|||
|
||||
Uses the pkg_resources module if available and no standard file
|
||||
found at the calculated location.
|
||||
|
||||
It is possible to specify different location for zoneinfo
|
||||
subdir by using the PYTZ_TZDATADIR environment variable.
|
||||
"""
|
||||
name_parts = name.lstrip('/').split('/')
|
||||
for part in name_parts:
|
||||
if part == os.path.pardir or os.path.sep in part:
|
||||
raise ValueError('Bad path segment: %r' % part)
|
||||
filename = os.path.join(os.path.dirname(__file__),
|
||||
'zoneinfo', *name_parts)
|
||||
if not os.path.exists(filename):
|
||||
# http://bugs.launchpad.net/bugs/383171 - we avoid using this
|
||||
# unless absolutely necessary to help when a broken version of
|
||||
# pkg_resources is installed.
|
||||
try:
|
||||
from pkg_resources import resource_stream
|
||||
except ImportError:
|
||||
resource_stream = None
|
||||
zoneinfo_dir = os.environ.get('PYTZ_TZDATADIR', None)
|
||||
if zoneinfo_dir is not None:
|
||||
filename = os.path.join(zoneinfo_dir, *name_parts)
|
||||
else:
|
||||
filename = os.path.join(os.path.dirname(__file__),
|
||||
'zoneinfo', *name_parts)
|
||||
if not os.path.exists(filename):
|
||||
# http://bugs.launchpad.net/bugs/383171 - we avoid using this
|
||||
# unless absolutely necessary to help when a broken version of
|
||||
# pkg_resources is installed.
|
||||
try:
|
||||
from pkg_resources import resource_stream
|
||||
except ImportError:
|
||||
resource_stream = None
|
||||
|
||||
if resource_stream is not None:
|
||||
return resource_stream(__name__, 'zoneinfo/' + name)
|
||||
if resource_stream is not None:
|
||||
return resource_stream(__name__, 'zoneinfo/' + name)
|
||||
return open(filename, 'rb')
|
||||
|
||||
|
||||
|
@ -106,23 +117,9 @@ def resource_exists(name):
|
|||
return False
|
||||
|
||||
|
||||
# Enable this when we get some translations?
|
||||
# We want an i18n API that is useful to programs using Python's gettext
|
||||
# module, as well as the Zope3 i18n package. Perhaps we should just provide
|
||||
# the POT file and translations, and leave it up to callers to make use
|
||||
# of them.
|
||||
#
|
||||
# t = gettext.translation(
|
||||
# 'pytz', os.path.join(os.path.dirname(__file__), 'locales'),
|
||||
# fallback=True
|
||||
# )
|
||||
# def _(timezone_name):
|
||||
# """Translate a timezone name using the current locale, returning Unicode"""
|
||||
# return t.ugettext(timezone_name)
|
||||
|
||||
|
||||
_tzinfo_cache = {}
|
||||
|
||||
|
||||
def timezone(zone):
|
||||
r''' Return a datetime.tzinfo implementation for the given timezone
|
||||
|
||||
|
@ -160,6 +157,9 @@ def timezone(zone):
|
|||
Unknown
|
||||
|
||||
'''
|
||||
if zone is None:
|
||||
raise UnknownTimeZoneError(None)
|
||||
|
||||
if zone.upper() == 'UTC':
|
||||
return utc
|
||||
|
||||
|
@ -169,9 +169,9 @@ def timezone(zone):
|
|||
# All valid timezones are ASCII
|
||||
raise UnknownTimeZoneError(zone)
|
||||
|
||||
zone = _unmunge_zone(zone)
|
||||
zone = _case_insensitive_zone_lookup(_unmunge_zone(zone))
|
||||
if zone not in _tzinfo_cache:
|
||||
if zone in all_timezones_set:
|
||||
if zone in all_timezones_set: # noqa
|
||||
fp = open_resource(zone)
|
||||
try:
|
||||
_tzinfo_cache[zone] = build_tzinfo(zone, fp)
|
||||
|
@ -188,11 +188,22 @@ def _unmunge_zone(zone):
|
|||
return zone.replace('_plus_', '+').replace('_minus_', '-')
|
||||
|
||||
|
||||
_all_timezones_lower_to_standard = None
|
||||
|
||||
|
||||
def _case_insensitive_zone_lookup(zone):
|
||||
"""case-insensitively matching timezone, else return zone unchanged"""
|
||||
global _all_timezones_lower_to_standard
|
||||
if _all_timezones_lower_to_standard is None:
|
||||
_all_timezones_lower_to_standard = dict((tz.lower(), tz) for tz in all_timezones) # noqa
|
||||
return _all_timezones_lower_to_standard.get(zone.lower()) or zone # noqa
|
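A quick illustration of the case-insensitive lookup added above, new in this pytz update:

import pytz

tz = pytz.timezone('europe/paris')   # mixed case now resolves to the canonical name
print(tz.zone)                       # 'Europe/Paris'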
||||
|
||||
|
||||
ZERO = datetime.timedelta(0)
|
||||
HOUR = datetime.timedelta(hours=1)
|
||||
|
||||
|
||||
class UTC(datetime.tzinfo):
|
||||
class UTC(BaseTzInfo):
|
||||
"""UTC
|
||||
|
||||
Optimized UTC implementation. It unpickles using the single module global
|
||||
|
@ -275,6 +286,8 @@ def _UTC():
|
|||
False
|
||||
"""
|
||||
return utc
|
||||
|
||||
|
||||
_UTC.__safe_for_unpickling__ = True
|
||||
|
||||
|
||||
|
@ -285,9 +298,10 @@ def _p(*args):
|
|||
by shortening the path.
|
||||
"""
|
||||
return unpickler(*args)
|
||||
_p.__safe_for_unpickling__ = True
|
||||
|
||||
|
||||
_p.__safe_for_unpickling__ = True
|
||||
|
||||
|
||||
class _CountryTimezoneDict(LazyDict):
|
||||
"""Map ISO 3166 country code to a list of timezone names commonly used
|
||||
|
@ -334,7 +348,7 @@ class _CountryTimezoneDict(LazyDict):
|
|||
if line.startswith('#'):
|
||||
continue
|
||||
code, coordinates, zone = line.split(None, 4)[:3]
|
||||
if zone not in all_timezones_set:
|
||||
if zone not in all_timezones_set: # noqa
|
||||
continue
|
||||
try:
|
||||
data[code].append(zone)
|
||||
|
@ -344,6 +358,7 @@ class _CountryTimezoneDict(LazyDict):
|
|||
finally:
|
||||
zone_tab.close()
|
||||
|
||||
|
||||
country_timezones = _CountryTimezoneDict()
|
||||
|
||||
|
||||
|
@ -367,6 +382,7 @@ class _CountryNameDict(LazyDict):
|
|||
finally:
|
||||
zone_tab.close()
|
||||
|
||||
|
||||
country_names = _CountryNameDict()
|
||||
|
||||
|
||||
|
@ -374,7 +390,7 @@ country_names = _CountryNameDict()
|
|||
|
||||
class _FixedOffset(datetime.tzinfo):
|
||||
|
||||
zone = None # to match the standard pytz API
|
||||
zone = None # to match the standard pytz API
|
||||
|
||||
def __init__(self, minutes):
|
||||
if abs(minutes) >= 1440:
|
||||
|
@ -412,24 +428,24 @@ class _FixedOffset(datetime.tzinfo):
|
|||
return dt.astimezone(self)
|
||||
|
||||
|
||||
def FixedOffset(offset, _tzinfos = {}):
|
||||
def FixedOffset(offset, _tzinfos={}):
|
||||
"""return a fixed-offset timezone based off a number of minutes.
|
||||
|
||||
>>> one = FixedOffset(-330)
|
||||
>>> one
|
||||
pytz.FixedOffset(-330)
|
||||
>>> one.utcoffset(datetime.datetime.now())
|
||||
datetime.timedelta(-1, 66600)
|
||||
>>> one.dst(datetime.datetime.now())
|
||||
datetime.timedelta(0)
|
||||
>>> str(one.utcoffset(datetime.datetime.now()))
|
||||
'-1 day, 18:30:00'
|
||||
>>> str(one.dst(datetime.datetime.now()))
|
||||
'0:00:00'
|
||||
|
||||
>>> two = FixedOffset(1380)
|
||||
>>> two
|
||||
pytz.FixedOffset(1380)
|
||||
>>> two.utcoffset(datetime.datetime.now())
|
||||
datetime.timedelta(0, 82800)
|
||||
>>> two.dst(datetime.datetime.now())
|
||||
datetime.timedelta(0)
|
||||
>>> str(two.utcoffset(datetime.datetime.now()))
|
||||
'23:00:00'
|
||||
>>> str(two.dst(datetime.datetime.now()))
|
||||
'0:00:00'
|
||||
|
||||
The datetime.timedelta must be between the range of -1 and 1 day,
|
||||
non-inclusive.
|
||||
|
@ -478,18 +494,19 @@ def FixedOffset(offset, _tzinfos = {}):
|
|||
|
||||
return info
|
||||
|
||||
|
||||
FixedOffset.__safe_for_unpickling__ = True
|
||||
|
||||
|
||||
def _test():
|
||||
import doctest, os, sys
|
||||
import doctest
|
||||
sys.path.insert(0, os.pardir)
|
||||
import pytz
|
||||
return doctest.testmod(pytz)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
_test()
|
||||
|
||||
all_timezones = \
|
||||
['Africa/Abidjan',
|
||||
'Africa/Accra',
|
||||
|
@ -792,6 +809,7 @@ all_timezones = \
|
|||
'Asia/Pontianak',
|
||||
'Asia/Pyongyang',
|
||||
'Asia/Qatar',
|
||||
'Asia/Qostanay',
|
||||
'Asia/Qyzylorda',
|
||||
'Asia/Rangoon',
|
||||
'Asia/Riyadh',
|
||||
|
@ -865,7 +883,6 @@ all_timezones = \
|
|||
'CST6CDT',
|
||||
'Canada/Atlantic',
|
||||
'Canada/Central',
|
||||
'Canada/East-Saskatchewan',
|
||||
'Canada/Eastern',
|
||||
'Canada/Mountain',
|
||||
'Canada/Newfoundland',
|
||||
|
@ -1077,7 +1094,6 @@ all_timezones = \
|
|||
'US/Michigan',
|
||||
'US/Mountain',
|
||||
'US/Pacific',
|
||||
'US/Pacific-New',
|
||||
'US/Samoa',
|
||||
'UTC',
|
||||
'Universal',
|
||||
|
@ -1358,6 +1374,7 @@ common_timezones = \
|
|||
'Asia/Pontianak',
|
||||
'Asia/Pyongyang',
|
||||
'Asia/Qatar',
|
||||
'Asia/Qostanay',
|
||||
'Asia/Qyzylorda',
|
||||
'Asia/Riyadh',
|
||||
'Asia/Sakhalin',
|
||||
|
|
|
@ -5,7 +5,7 @@ Custom exceptions raised by pytz.
|
|||
__all__ = [
|
||||
'UnknownTimeZoneError', 'InvalidTimeError', 'AmbiguousTimeError',
|
||||
'NonExistentTimeError',
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
class UnknownTimeZoneError(KeyError):
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
from threading import RLock
|
||||
try:
|
||||
from UserDict import DictMixin
|
||||
except ImportError:
|
||||
from collections import Mapping as DictMixin
|
||||
from collections.abc import Mapping as DictMixin
|
||||
except ImportError: # Python < 3.3
|
||||
try:
|
||||
from UserDict import DictMixin # Python 2
|
||||
except ImportError: # Python 3.0-3.3
|
||||
from collections import Mapping as DictMixin
|
||||
|
||||
|
||||
# With lazy loading, we might end up with multiple threads triggering
|
||||
|
@ -13,6 +16,7 @@ _fill_lock = RLock()
|
|||
class LazyDict(DictMixin):
|
||||
"""Dictionary populated on first use."""
|
||||
data = None
|
||||
|
||||
def __getitem__(self, key):
|
||||
if self.data is None:
|
||||
_fill_lock.acquire()
|
||||
|
|
|
@ -5,17 +5,28 @@ Used for testing against as they are only correct for the years
|
|||
'''
|
||||
|
||||
from datetime import tzinfo, timedelta, datetime
|
||||
from pytz import utc, UTC, HOUR, ZERO
|
||||
from pytz import HOUR, ZERO, UTC
|
||||
|
||||
__all__ = [
|
||||
'FixedOffset',
|
||||
'LocalTimezone',
|
||||
'USTimeZone',
|
||||
'Eastern',
|
||||
'Central',
|
||||
'Mountain',
|
||||
'Pacific',
|
||||
'UTC'
|
||||
]
|
||||
|
||||
|
||||
# A class building tzinfo objects for fixed-offset time zones.
|
||||
# Note that FixedOffset(0, "UTC") is a different way to build a
|
||||
# UTC tzinfo object.
|
||||
|
||||
class FixedOffset(tzinfo):
|
||||
"""Fixed offset in minutes east from UTC."""
|
||||
|
||||
def __init__(self, offset, name):
|
||||
self.__offset = timedelta(minutes = offset)
|
||||
self.__offset = timedelta(minutes=offset)
|
||||
self.__name = name
|
||||
|
||||
def utcoffset(self, dt):
|
||||
|
@ -27,18 +38,19 @@ class FixedOffset(tzinfo):
|
|||
def dst(self, dt):
|
||||
return ZERO
|
||||
|
||||
# A class capturing the platform's idea of local time.
|
||||
|
||||
import time as _time
|
||||
|
||||
STDOFFSET = timedelta(seconds = -_time.timezone)
|
||||
STDOFFSET = timedelta(seconds=-_time.timezone)
|
||||
if _time.daylight:
|
||||
DSTOFFSET = timedelta(seconds = -_time.altzone)
|
||||
DSTOFFSET = timedelta(seconds=-_time.altzone)
|
||||
else:
|
||||
DSTOFFSET = STDOFFSET
|
||||
|
||||
DSTDIFF = DSTOFFSET - STDOFFSET
|
||||
|
||||
|
||||
# A class capturing the platform's idea of local time.
|
||||
class LocalTimezone(tzinfo):
|
||||
|
||||
def utcoffset(self, dt):
|
||||
|
@ -66,7 +78,6 @@ class LocalTimezone(tzinfo):
|
|||
|
||||
Local = LocalTimezone()
|
||||
|
||||
# A complete implementation of current DST rules for major US time zones.
|
||||
|
||||
def first_sunday_on_or_after(dt):
|
||||
days_to_go = 6 - dt.weekday()
|
||||
|
@ -74,12 +85,15 @@ def first_sunday_on_or_after(dt):
|
|||
dt += timedelta(days_to_go)
|
||||
return dt
|
||||
|
||||
|
||||
# In the US, DST starts at 2am (standard time) on the first Sunday in April.
|
||||
DSTSTART = datetime(1, 4, 1, 2)
|
||||
# and ends at 2am (DST time; 1am standard time) on the last Sunday of Oct.
|
||||
# which is the first Sunday on or after Oct 25.
|
||||
DSTEND = datetime(1, 10, 25, 1)
|
||||
|
||||
|
||||
# A complete implementation of current DST rules for major US time zones.
|
||||
class USTimeZone(tzinfo):
|
||||
|
||||
def __init__(self, hours, reprname, stdname, dstname):
|
||||
|
@ -120,8 +134,7 @@ class USTimeZone(tzinfo):
|
|||
else:
|
||||
return ZERO
|
||||
|
||||
Eastern = USTimeZone(-5, "Eastern", "EST", "EDT")
|
||||
Central = USTimeZone(-6, "Central", "CST", "CDT")
|
||||
Eastern = USTimeZone(-5, "Eastern", "EST", "EDT")
|
||||
Central = USTimeZone(-6, "Central", "CST", "CDT")
|
||||
Mountain = USTimeZone(-7, "Mountain", "MST", "MDT")
|
||||
Pacific = USTimeZone(-8, "Pacific", "PST", "PDT")
|
||||
|
||||
Pacific = USTimeZone(-8, "Pacific", "PST", "PDT")
|
||||
|
|
|
@ -1,34 +0,0 @@
|
|||
# -*- coding: ascii -*-
|
||||
|
||||
from doctest import DocFileSuite
|
||||
import unittest, os.path, sys
|
||||
|
||||
THIS_DIR = os.path.dirname(__file__)
|
||||
|
||||
README = os.path.join(THIS_DIR, os.pardir, os.pardir, 'README.txt')
|
||||
|
||||
|
||||
class DocumentationTestCase(unittest.TestCase):
|
||||
def test_readme_encoding(self):
|
||||
'''Confirm the README.txt is pure ASCII.'''
|
||||
f = open(README, 'rb')
|
||||
try:
|
||||
f.read().decode('ASCII')
|
||||
finally:
|
||||
f.close()
|
||||
|
||||
|
||||
def test_suite():
|
||||
"For the Z3 test runner"
|
||||
return unittest.TestSuite((
|
||||
DocumentationTestCase('test_readme_encoding'),
|
||||
DocFileSuite(os.path.join(os.pardir, os.pardir, 'README.txt'))))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(
|
||||
THIS_DIR, os.pardir, os.pardir
|
||||
)))
|
||||
unittest.main(defaultTest='test_suite')
|
||||
|
||||
|
|
@ -1,313 +0,0 @@
|
|||
from operator import *
|
||||
import os.path
|
||||
import sys
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Only munge path if invoked as a script. Testrunners should have setup
|
||||
# the paths already
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.pardir, os.pardir)))
|
||||
|
||||
|
||||
from pytz.lazy import LazyList, LazySet
|
||||
|
||||
|
||||
class LazyListTestCase(unittest.TestCase):
|
||||
initial_data = [3,2,1]
|
||||
|
||||
def setUp(self):
|
||||
self.base = [3, 2, 1]
|
||||
self.lesser = [2, 1, 0]
|
||||
self.greater = [4, 3, 2]
|
||||
|
||||
self.lazy = LazyList(iter(list(self.base)))
|
||||
|
||||
def test_unary_ops(self):
|
||||
unary_ops = [str, repr, len, bool, not_]
|
||||
try:
|
||||
unary_ops.append(unicode)
|
||||
except NameError:
|
||||
pass # unicode no longer exists in Python 3.
|
||||
|
||||
for op in unary_ops:
|
||||
self.assertEqual(
|
||||
op(self.lazy),
|
||||
op(self.base), str(op))
|
||||
|
||||
def test_binary_ops(self):
|
||||
binary_ops = [eq, ge, gt, le, lt, ne, add, concat]
|
||||
try:
|
||||
binary_ops.append(cmp)
|
||||
except NameError:
|
||||
pass # cmp no longer exists in Python 3.
|
||||
|
||||
for op in binary_ops:
|
||||
self.assertEqual(
|
||||
op(self.lazy, self.lazy),
|
||||
op(self.base, self.base), str(op))
|
||||
for other in [self.base, self.lesser, self.greater]:
|
||||
self.assertEqual(
|
||||
op(self.lazy, other),
|
||||
op(self.base, other), '%s %s' % (op, other))
|
||||
self.assertEqual(
|
||||
op(other, self.lazy),
|
||||
op(other, self.base), '%s %s' % (op, other))
|
||||
|
||||
# Multiplication
|
||||
self.assertEqual(self.lazy * 3, self.base * 3)
|
||||
self.assertEqual(3 * self.lazy, 3 * self.base)
|
||||
|
||||
# Contains
|
||||
self.assertTrue(2 in self.lazy)
|
||||
self.assertFalse(42 in self.lazy)
|
||||
|
||||
def test_iadd(self):
|
||||
self.lazy += [1]
|
||||
self.base += [1]
|
||||
self.assertEqual(self.lazy, self.base)
|
||||
|
||||
def test_bool(self):
|
||||
self.assertTrue(bool(self.lazy))
|
||||
self.assertFalse(bool(LazyList()))
|
||||
self.assertFalse(bool(LazyList(iter([]))))
|
||||
|
||||
def test_hash(self):
|
||||
self.assertRaises(TypeError, hash, self.lazy)
|
||||
|
||||
def test_isinstance(self):
|
||||
self.assertTrue(isinstance(self.lazy, list))
|
||||
self.assertFalse(isinstance(self.lazy, tuple))
|
||||
|
||||
def test_callable(self):
|
||||
try:
|
||||
callable
|
||||
except NameError:
|
||||
return # No longer exists with Python 3.
|
||||
self.assertFalse(callable(self.lazy))
|
||||
|
||||
def test_append(self):
|
||||
self.base.append('extra')
|
||||
self.lazy.append('extra')
|
||||
self.assertEqual(self.lazy, self.base)
|
||||
|
||||
def test_count(self):
|
||||
self.assertEqual(self.lazy.count(2), 1)
|
||||
|
||||
def test_index(self):
|
||||
self.assertEqual(self.lazy.index(2), 1)
|
||||
|
||||
def test_extend(self):
|
||||
self.base.extend([6, 7])
|
||||
self.lazy.extend([6, 7])
|
||||
self.assertEqual(self.lazy, self.base)
|
||||
|
||||
def test_insert(self):
|
||||
self.base.insert(0, 'ping')
|
||||
self.lazy.insert(0, 'ping')
|
||||
self.assertEqual(self.lazy, self.base)
|
||||
|
||||
def test_pop(self):
|
||||
self.assertEqual(self.lazy.pop(), self.base.pop())
|
||||
self.assertEqual(self.lazy, self.base)
|
||||
|
||||
def test_remove(self):
|
||||
self.base.remove(2)
|
||||
self.lazy.remove(2)
|
||||
self.assertEqual(self.lazy, self.base)
|
||||
|
||||
def test_reverse(self):
|
||||
self.base.reverse()
|
||||
self.lazy.reverse()
|
||||
self.assertEqual(self.lazy, self.base)
|
||||
|
||||
def test_reversed(self):
|
||||
self.assertEqual(list(reversed(self.lazy)), list(reversed(self.base)))
|
||||
|
||||
def test_sort(self):
|
||||
self.base.sort()
|
||||
self.assertNotEqual(self.lazy, self.base, 'Test data already sorted')
|
||||
self.lazy.sort()
|
||||
self.assertEqual(self.lazy, self.base)
|
||||
|
||||
def test_sorted(self):
|
||||
self.assertEqual(sorted(self.lazy), sorted(self.base))
|
||||
|
||||
def test_getitem(self):
|
||||
for idx in range(-len(self.base), len(self.base)):
|
||||
self.assertEqual(self.lazy[idx], self.base[idx])
|
||||
|
||||
def test_setitem(self):
|
||||
for idx in range(-len(self.base), len(self.base)):
|
||||
self.base[idx] = idx + 1000
|
||||
self.assertNotEqual(self.lazy, self.base)
|
||||
self.lazy[idx] = idx + 1000
|
||||
self.assertEqual(self.lazy, self.base)
|
||||
|
||||
def test_delitem(self):
|
||||
del self.base[0]
|
||||
self.assertNotEqual(self.lazy, self.base)
|
||||
del self.lazy[0]
|
||||
self.assertEqual(self.lazy, self.base)
|
||||
|
||||
del self.base[-2]
|
||||
self.assertNotEqual(self.lazy, self.base)
|
||||
del self.lazy[-2]
|
||||
self.assertEqual(self.lazy, self.base)
|
||||
|
||||
def test_iter(self):
|
||||
self.assertEqual(list(iter(self.lazy)), list(iter(self.base)))
|
||||
|
||||
def test_getslice(self):
|
||||
for i in range(-len(self.base), len(self.base)):
|
||||
for j in range(-len(self.base), len(self.base)):
|
||||
for step in [-1, 1]:
|
||||
self.assertEqual(self.lazy[i:j:step], self.base[i:j:step])
|
||||
|
||||
def test_setslice(self):
|
||||
for i in range(-len(self.base), len(self.base)):
|
||||
for j in range(-len(self.base), len(self.base)):
|
||||
for step in [-1, 1]:
|
||||
replacement = range(0, len(self.base[i:j:step]))
|
||||
self.base[i:j:step] = replacement
|
||||
self.lazy[i:j:step] = replacement
|
||||
self.assertEqual(self.lazy, self.base)
|
||||
|
||||
def test_delslice(self):
|
||||
del self.base[0:1]
|
||||
del self.lazy[0:1]
|
||||
self.assertEqual(self.lazy, self.base)
|
||||
|
||||
del self.base[-1:1:-1]
|
||||
del self.lazy[-1:1:-1]
|
||||
self.assertEqual(self.lazy, self.base)
|
||||
|
||||
|
||||
class LazySetTestCase(unittest.TestCase):
|
||||
initial_data = set([3,2,1])
|
||||
|
||||
def setUp(self):
|
||||
self.base = set([3, 2, 1])
|
||||
self.lazy = LazySet(iter(set(self.base)))
|
||||
|
||||
def test_unary_ops(self):
|
||||
# These ops just need to work.
|
||||
unary_ops = [str, repr]
|
||||
try:
|
||||
unary_ops.append(unicode)
|
||||
except NameError:
|
||||
pass # unicode no longer exists in Python 3.
|
||||
|
||||
for op in unary_ops:
|
||||
op(self.lazy) # These ops just need to work.
|
||||
|
||||
# These ops should return identical values as a real set.
|
||||
unary_ops = [len, bool, not_]
|
||||
|
||||
for op in unary_ops:
|
||||
self.assertEqual(
|
||||
op(self.lazy),
|
||||
op(self.base), '%s(lazy) == %r' % (op, op(self.lazy)))
|
||||
|
||||
def test_binary_ops(self):
|
||||
binary_ops = [eq, ge, gt, le, lt, ne, sub, and_, or_, xor]
|
||||
try:
|
||||
binary_ops.append(cmp)
|
||||
except NameError:
|
||||
pass # cmp no longer exists in Python 3.
|
||||
|
||||
for op in binary_ops:
|
||||
self.assertEqual(
|
||||
op(self.lazy, self.lazy),
|
||||
op(self.base, self.base), str(op))
|
||||
self.assertEqual(
|
||||
op(self.lazy, self.base),
|
||||
op(self.base, self.base), str(op))
|
||||
self.assertEqual(
|
||||
op(self.base, self.lazy),
|
||||
op(self.base, self.base), str(op))
|
||||
|
||||
# Contains
|
||||
self.assertTrue(2 in self.lazy)
|
||||
self.assertFalse(42 in self.lazy)
|
||||
|
||||
def test_iops(self):
|
||||
try:
|
||||
iops = [isub, iand, ior, ixor]
|
||||
except NameError:
|
||||
return # Don't exist in older Python versions.
|
||||
for op in iops:
|
||||
# Mutating operators, so make fresh copies.
|
||||
lazy = LazySet(self.base)
|
||||
base = self.base.copy()
|
||||
op(lazy, set([1]))
|
||||
op(base, set([1]))
|
||||
self.assertEqual(lazy, base, str(op))
|
||||
|
||||
def test_bool(self):
|
||||
self.assertTrue(bool(self.lazy))
|
||||
self.assertFalse(bool(LazySet()))
|
||||
self.assertFalse(bool(LazySet(iter([]))))
|
||||
|
||||
def test_hash(self):
|
||||
self.assertRaises(TypeError, hash, self.lazy)
|
||||
|
||||
def test_isinstance(self):
|
||||
self.assertTrue(isinstance(self.lazy, set))
|
||||
|
||||
def test_callable(self):
|
||||
try:
|
||||
callable
|
||||
except NameError:
|
||||
return # No longer exists with Python 3.
|
||||
self.assertFalse(callable(self.lazy))
|
||||
|
||||
def test_add(self):
|
||||
self.base.add('extra')
|
||||
self.lazy.add('extra')
|
||||
self.assertEqual(self.lazy, self.base)
|
||||
|
||||
def test_copy(self):
|
||||
self.assertEqual(self.lazy.copy(), self.base)
|
||||
|
||||
def test_method_ops(self):
|
||||
ops = [
|
||||
'difference', 'intersection', 'isdisjoint',
|
||||
'issubset', 'issuperset', 'symmetric_difference', 'union',
|
||||
'difference_update', 'intersection_update',
|
||||
'symmetric_difference_update', 'update']
|
||||
for op in ops:
|
||||
if not hasattr(set, op):
|
||||
continue # Not in this version of Python.
|
||||
# Make a copy, as some of the ops are mutating.
|
||||
lazy = LazySet(set(self.base))
|
||||
base = set(self.base)
|
||||
self.assertEqual(
|
||||
getattr(self.lazy, op)(set([1])),
|
||||
getattr(self.base, op)(set([1])), op)
|
||||
self.assertEqual(self.lazy, self.base, op)
|
||||
|
||||
def test_discard(self):
|
||||
self.base.discard(1)
|
||||
self.assertNotEqual(self.lazy, self.base)
|
||||
self.lazy.discard(1)
|
||||
self.assertEqual(self.lazy, self.base)
|
||||
|
||||
def test_pop(self):
|
||||
self.assertEqual(self.lazy.pop(), self.base.pop())
|
||||
self.assertEqual(self.lazy, self.base)
|
||||
|
||||
def test_remove(self):
|
||||
self.base.remove(2)
|
||||
self.lazy.remove(2)
|
||||
self.assertEqual(self.lazy, self.base)
|
||||
|
||||
def test_clear(self):
|
||||
self.lazy.clear()
|
||||
self.assertEqual(self.lazy, set())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
warnings.simplefilter("error") # Warnings should be fatal in tests.
|
||||
unittest.main()
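
The LazySet exercised above wraps an iterable and defers building the underlying set until an operation first touches it, after which it behaves like a plain set (hence the isinstance(self.lazy, set) check). The sketch below shows only that lazy-fill pattern, not pytz's actual pytz.lazy implementation; the _materialize helper and the handful of overridden methods are illustrative assumptions:

class LazySet(set):
    """Set that fills itself from an iterable on first use (illustrative sketch)."""

    def __init__(self, iterable=None):
        super(LazySet, self).__init__()
        self._source = iterable          # consumed lazily
        self._filled = iterable is None  # nothing to fill if no source given

    def _materialize(self):
        # Pull the deferred elements in exactly once.
        if not self._filled:
            set.update(self, self._source)
            self._filled = True

    def __len__(self):
        self._materialize()
        return set.__len__(self)

    def __contains__(self, item):
        self._materialize()
        return set.__contains__(self, item)

    def __iter__(self):
        self._materialize()
        return set.__iter__(self)

    def add(self, item):
        self._materialize()
        set.add(self, item)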
|
|
@ -1,844 +0,0 @@
|
|||
# -*- coding: ascii -*-
|
||||
|
||||
import sys, os, os.path
|
||||
import unittest, doctest
|
||||
try:
|
||||
import cPickle as pickle
|
||||
except ImportError:
|
||||
import pickle
|
||||
from datetime import datetime, time, timedelta, tzinfo
|
||||
import warnings
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Only munge path if invoked as a script. Testrunners should have setup
|
||||
# the paths already
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.pardir, os.pardir)))
|
||||
|
||||
import pytz
|
||||
from pytz import reference
|
||||
from pytz.tzfile import _byte_string
|
||||
from pytz.tzinfo import DstTzInfo, StaticTzInfo
|
||||
|
||||
# I test for expected version to ensure the correct version of pytz is
|
||||
# actually being tested.
|
||||
EXPECTED_VERSION='2017.2'
|
||||
EXPECTED_OLSON_VERSION='2017b'
|
||||
|
||||
fmt = '%Y-%m-%d %H:%M:%S %Z%z'
|
||||
|
||||
NOTIME = timedelta(0)
|
||||
|
||||
# GMT is a tzinfo.StaticTzInfo--the class we primarily want to test--while
|
||||
# UTC is the reference implementation. They both have the same timezone meaning.
|
||||
UTC = pytz.timezone('UTC')
|
||||
GMT = pytz.timezone('GMT')
|
||||
assert isinstance(GMT, StaticTzInfo), 'GMT is no longer a StaticTzInfo'
|
||||
|
||||
def prettydt(dt):
|
||||
"""datetime as a string using a known format.
|
||||
|
||||
We don't use strftime as it doesn't handle years earlier than 1900
|
||||
per http://bugs.python.org/issue1777412
|
||||
"""
|
||||
if dt.utcoffset() >= timedelta(0):
|
||||
offset = '+%s' % (dt.utcoffset(),)
|
||||
else:
|
||||
offset = '-%s' % (-1 * dt.utcoffset(),)
|
||||
return '%04d-%02d-%02d %02d:%02d:%02d %s %s' % (
|
||||
dt.year, dt.month, dt.day,
|
||||
dt.hour, dt.minute, dt.second,
|
||||
dt.tzname(), offset)
|
||||
|
||||
|
||||
try:
|
||||
unicode
|
||||
except NameError:
|
||||
# Python 3.x doesn't have unicode(), making writing code
|
||||
# for Python 2.3 and Python 3.x a pain.
|
||||
unicode = str
|
||||
|
||||
|
||||
class BasicTest(unittest.TestCase):
|
||||
|
||||
def testVersion(self):
|
||||
# Ensuring the correct version of pytz has been loaded
|
||||
self.assertEqual(EXPECTED_VERSION, pytz.__version__,
|
||||
'Incorrect pytz version loaded. Import path is stuffed '
|
||||
'or this test needs updating. (Wanted %s, got %s)'
|
||||
% (EXPECTED_VERSION, pytz.__version__))
|
||||
|
||||
self.assertEqual(EXPECTED_OLSON_VERSION, pytz.OLSON_VERSION,
|
||||
'Incorrect pytz version loaded. Import path is stuffed '
|
||||
'or this test needs updating. (Wanted %s, got %s)'
|
||||
% (EXPECTED_OLSON_VERSION, pytz.OLSON_VERSION))
|
||||
|
||||
def testGMT(self):
|
||||
now = datetime.now(tz=GMT)
|
||||
self.assertTrue(now.utcoffset() == NOTIME)
|
||||
self.assertTrue(now.dst() == NOTIME)
|
||||
self.assertTrue(now.timetuple() == now.utctimetuple())
|
||||
self.assertTrue(now==now.replace(tzinfo=UTC))
|
||||
|
||||
def testReferenceUTC(self):
|
||||
now = datetime.now(tz=UTC)
|
||||
self.assertTrue(now.utcoffset() == NOTIME)
|
||||
self.assertTrue(now.dst() == NOTIME)
|
||||
self.assertTrue(now.timetuple() == now.utctimetuple())
|
||||
|
||||
def testUnknownOffsets(self):
|
||||
# This tzinfo behavior is required to make
|
||||
# datetime.time.{utcoffset, dst, tzname} work as documented.
|
||||
|
||||
dst_tz = pytz.timezone('US/Eastern')
|
||||
|
||||
# This information is not known when we don't have a date,
|
||||
# so return None per API.
|
||||
self.assertTrue(dst_tz.utcoffset(None) is None)
|
||||
self.assertTrue(dst_tz.dst(None) is None)
|
||||
# We don't know the abbreviation, but this is still a valid
|
||||
# tzname per the Python documentation.
|
||||
self.assertEqual(dst_tz.tzname(None), 'US/Eastern')
|
||||
|
||||
def clearCache(self):
|
||||
pytz._tzinfo_cache.clear()
|
||||
|
||||
def testUnicodeTimezone(self):
|
||||
# We need to ensure that cold lookups work for both Unicode
|
||||
# and traditional strings, and that the desired singleton is
|
||||
# returned.
|
||||
self.clearCache()
|
||||
eastern = pytz.timezone(unicode('US/Eastern'))
|
||||
self.assertTrue(eastern is pytz.timezone('US/Eastern'))
|
||||
|
||||
self.clearCache()
|
||||
eastern = pytz.timezone('US/Eastern')
|
||||
self.assertTrue(eastern is pytz.timezone(unicode('US/Eastern')))
|
||||
|
||||
def testStaticTzInfo(self):
|
||||
# Ensure that static timezones are correctly detected,
|
||||
# per lp:1602807
|
||||
static = pytz.timezone('Etc/GMT-4')
|
||||
self.assertTrue(isinstance(static, StaticTzInfo))
|
||||
|
||||
|
||||
class PicklingTest(unittest.TestCase):
|
||||
|
||||
def _roundtrip_tzinfo(self, tz):
|
||||
p = pickle.dumps(tz)
|
||||
unpickled_tz = pickle.loads(p)
|
||||
self.assertTrue(tz is unpickled_tz, '%s did not roundtrip' % tz.zone)
|
||||
|
||||
def _roundtrip_datetime(self, dt):
|
||||
# Ensure that the tzinfo attached to a datetime instance
|
||||
# is identical to the one returned. This is important for
|
||||
# DST timezones, as some state is stored in the tzinfo.
|
||||
tz = dt.tzinfo
|
||||
p = pickle.dumps(dt)
|
||||
unpickled_dt = pickle.loads(p)
|
||||
unpickled_tz = unpickled_dt.tzinfo
|
||||
self.assertTrue(tz is unpickled_tz, '%s did not roundtrip' % tz.zone)
|
||||
|
||||
def testDst(self):
|
||||
tz = pytz.timezone('Europe/Amsterdam')
|
||||
dt = datetime(2004, 2, 1, 0, 0, 0)
|
||||
|
||||
for localized_tz in tz._tzinfos.values():
|
||||
self._roundtrip_tzinfo(localized_tz)
|
||||
self._roundtrip_datetime(dt.replace(tzinfo=localized_tz))
|
||||
|
||||
def testRoundtrip(self):
|
||||
dt = datetime(2004, 2, 1, 0, 0, 0)
|
||||
for zone in pytz.all_timezones:
|
||||
tz = pytz.timezone(zone)
|
||||
self._roundtrip_tzinfo(tz)
|
||||
|
||||
def testDatabaseFixes(self):
|
||||
# Hack the pickle to make it refer to a timezone abbreviation
|
||||
# that does not match anything. The unpickler should be able
|
||||
# to repair this case
|
||||
tz = pytz.timezone('Australia/Melbourne')
|
||||
p = pickle.dumps(tz)
|
||||
tzname = tz._tzname
|
||||
hacked_p = p.replace(_byte_string(tzname),
|
||||
_byte_string('?'*len(tzname)))
|
||||
self.assertNotEqual(p, hacked_p)
|
||||
unpickled_tz = pickle.loads(hacked_p)
|
||||
self.assertTrue(tz is unpickled_tz)
|
||||
|
||||
# Simulate a database correction. In this case, the incorrect
|
||||
# data will continue to be used.
|
||||
p = pickle.dumps(tz)
|
||||
new_utcoffset = tz._utcoffset.seconds + 42
|
||||
|
||||
# Python 3 introduced a new pickle protocol where numbers are stored in
|
||||
# hexadecimal representation. Here we extract the pickle
|
||||
# representation of the number for the current Python version.
|
||||
old_pickle_pattern = pickle.dumps(tz._utcoffset.seconds)[3:-1]
|
||||
new_pickle_pattern = pickle.dumps(new_utcoffset)[3:-1]
|
||||
hacked_p = p.replace(old_pickle_pattern, new_pickle_pattern)
|
||||
|
||||
self.assertNotEqual(p, hacked_p)
|
||||
unpickled_tz = pickle.loads(hacked_p)
|
||||
self.assertEqual(unpickled_tz._utcoffset.seconds, new_utcoffset)
|
||||
self.assertTrue(tz is not unpickled_tz)
|
||||
|
||||
def testOldPickles(self):
|
||||
# Ensure that applications serializing pytz instances as pickles
|
||||
# have no troubles upgrading to a new pytz release. These pickles
|
||||
# were created with pytz2006j.
|
||||
east1 = pickle.loads(_byte_string(
|
||||
"cpytz\n_p\np1\n(S'US/Eastern'\np2\nI-18000\n"
|
||||
"I0\nS'EST'\np3\ntRp4\n."
|
||||
))
|
||||
east2 = pytz.timezone('US/Eastern').localize(
|
||||
datetime(2006, 1, 1)).tzinfo
|
||||
self.assertTrue(east1 is east2)
|
||||
|
||||
# Confirm changes in name munging between 2006j and 2007c cause
|
||||
# no problems.
|
||||
pap1 = pickle.loads(_byte_string(
|
||||
"cpytz\n_p\np1\n(S'America/Port_minus_au_minus_Prince'"
|
||||
"\np2\nI-17340\nI0\nS'PPMT'\np3\ntRp4\n."))
|
||||
pap2 = pytz.timezone('America/Port-au-Prince').localize(
|
||||
datetime(1910, 1, 1)).tzinfo
|
||||
self.assertTrue(pap1 is pap2)
|
||||
|
||||
gmt1 = pickle.loads(_byte_string(
|
||||
"cpytz\n_p\np1\n(S'Etc/GMT_plus_10'\np2\ntRp3\n."))
|
||||
gmt2 = pytz.timezone('Etc/GMT+10')
|
||||
self.assertTrue(gmt1 is gmt2)
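
The pickle tests above depend on pytz's custom __reduce__, which serialises the zone name (plus offsets for DST-aware zones) so that unpickling returns the cached singleton rather than a fresh object. A small usage sketch, assuming pytz is importable:

import pickle

import pytz

melbourne = pytz.timezone('Australia/Melbourne')
roundtripped = pickle.loads(pickle.dumps(melbourne))

# Unpickling hands back the cached singleton, not a copy.
assert roundtripped is melbourne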
|
||||
|
||||
|
||||
class USEasternDSTStartTestCase(unittest.TestCase):
|
||||
tzinfo = pytz.timezone('US/Eastern')
|
||||
|
||||
# The UTC instant of the DST changeover
|
||||
transition_time = datetime(2002, 4, 7, 7, 0, 0, tzinfo=UTC)
|
||||
|
||||
# Increase for 'flexible' DST transitions due to 1 minute granularity
|
||||
# of Python's datetime library
|
||||
instant = timedelta(seconds=1)
|
||||
|
||||
# before transition
|
||||
before = {
|
||||
'tzname': 'EST',
|
||||
'utcoffset': timedelta(hours = -5),
|
||||
'dst': timedelta(hours = 0),
|
||||
}
|
||||
|
||||
# after transition
|
||||
after = {
|
||||
'tzname': 'EDT',
|
||||
'utcoffset': timedelta(hours = -4),
|
||||
'dst': timedelta(hours = 1),
|
||||
}
|
||||
|
||||
def _test_tzname(self, utc_dt, wanted):
|
||||
tzname = wanted['tzname']
|
||||
dt = utc_dt.astimezone(self.tzinfo)
|
||||
self.assertEqual(dt.tzname(), tzname,
|
||||
'Expected %s as tzname for %s. Got %s' % (
|
||||
tzname, str(utc_dt), dt.tzname()
|
||||
)
|
||||
)
|
||||
|
||||
def _test_utcoffset(self, utc_dt, wanted):
|
||||
utcoffset = wanted['utcoffset']
|
||||
dt = utc_dt.astimezone(self.tzinfo)
|
||||
self.assertEqual(
|
||||
dt.utcoffset(), wanted['utcoffset'],
|
||||
'Expected %s as utcoffset for %s. Got %s' % (
|
||||
utcoffset, utc_dt, dt.utcoffset()
|
||||
)
|
||||
)
|
||||
|
||||
def _test_dst(self, utc_dt, wanted):
|
||||
dst = wanted['dst']
|
||||
dt = utc_dt.astimezone(self.tzinfo)
|
||||
self.assertEqual(dt.dst(),dst,
|
||||
'Expected %s as dst for %s. Got %s' % (
|
||||
dst, utc_dt, dt.dst()
|
||||
)
|
||||
)
|
||||
|
||||
def test_arithmetic(self):
|
||||
utc_dt = self.transition_time
|
||||
|
||||
for days in range(-420, 720, 20):
|
||||
delta = timedelta(days=days)
|
||||
|
||||
# Make sure we can get back where we started
|
||||
dt = utc_dt.astimezone(self.tzinfo)
|
||||
dt2 = dt + delta
|
||||
dt2 = dt2 - delta
|
||||
self.assertEqual(dt, dt2)
|
||||
|
||||
# Make sure arithmetic crossing DST boundaries ends
|
||||
# up in the correct timezone after normalization
|
||||
utc_plus_delta = (utc_dt + delta).astimezone(self.tzinfo)
|
||||
local_plus_delta = self.tzinfo.normalize(dt + delta)
|
||||
self.assertEqual(
|
||||
prettydt(utc_plus_delta),
|
||||
prettydt(local_plus_delta),
|
||||
'Incorrect result for delta==%d days. Wanted %r. Got %r'%(
|
||||
days,
|
||||
prettydt(utc_plus_delta),
|
||||
prettydt(local_plus_delta),
|
||||
)
|
||||
)
|
||||
|
||||
def _test_all(self, utc_dt, wanted):
|
||||
self._test_utcoffset(utc_dt, wanted)
|
||||
self._test_tzname(utc_dt, wanted)
|
||||
self._test_dst(utc_dt, wanted)
|
||||
|
||||
def testDayBefore(self):
|
||||
self._test_all(
|
||||
self.transition_time - timedelta(days=1), self.before
|
||||
)
|
||||
|
||||
def testTwoHoursBefore(self):
|
||||
self._test_all(
|
||||
self.transition_time - timedelta(hours=2), self.before
|
||||
)
|
||||
|
||||
def testHourBefore(self):
|
||||
self._test_all(
|
||||
self.transition_time - timedelta(hours=1), self.before
|
||||
)
|
||||
|
||||
def testInstantBefore(self):
|
||||
self._test_all(
|
||||
self.transition_time - self.instant, self.before
|
||||
)
|
||||
|
||||
def testTransition(self):
|
||||
self._test_all(
|
||||
self.transition_time, self.after
|
||||
)
|
||||
|
||||
def testInstantAfter(self):
|
||||
self._test_all(
|
||||
self.transition_time + self.instant, self.after
|
||||
)
|
||||
|
||||
def testHourAfter(self):
|
||||
self._test_all(
|
||||
self.transition_time + timedelta(hours=1), self.after
|
||||
)
|
||||
|
||||
def testTwoHoursAfter(self):
|
||||
self._test_all(
|
||||
self.transition_time + timedelta(hours=1), self.after
|
||||
)
|
||||
|
||||
def testDayAfter(self):
|
||||
self._test_all(
|
||||
self.transition_time + timedelta(days=1), self.after
|
||||
)
|
||||
|
||||
|
||||
class USEasternDSTEndTestCase(USEasternDSTStartTestCase):
|
||||
tzinfo = pytz.timezone('US/Eastern')
|
||||
transition_time = datetime(2002, 10, 27, 6, 0, 0, tzinfo=UTC)
|
||||
before = {
|
||||
'tzname': 'EDT',
|
||||
'utcoffset': timedelta(hours = -4),
|
||||
'dst': timedelta(hours = 1),
|
||||
}
|
||||
after = {
|
||||
'tzname': 'EST',
|
||||
'utcoffset': timedelta(hours = -5),
|
||||
'dst': timedelta(hours = 0),
|
||||
}
|
||||
|
||||
|
||||
class USEasternEPTStartTestCase(USEasternDSTStartTestCase):
|
||||
transition_time = datetime(1945, 8, 14, 23, 0, 0, tzinfo=UTC)
|
||||
before = {
|
||||
'tzname': 'EWT',
|
||||
'utcoffset': timedelta(hours = -4),
|
||||
'dst': timedelta(hours = 1),
|
||||
}
|
||||
after = {
|
||||
'tzname': 'EPT',
|
||||
'utcoffset': timedelta(hours = -4),
|
||||
'dst': timedelta(hours = 1),
|
||||
}
|
||||
|
||||
|
||||
class USEasternEPTEndTestCase(USEasternDSTStartTestCase):
|
||||
transition_time = datetime(1945, 9, 30, 6, 0, 0, tzinfo=UTC)
|
||||
before = {
|
||||
'tzname': 'EPT',
|
||||
'utcoffset': timedelta(hours = -4),
|
||||
'dst': timedelta(hours = 1),
|
||||
}
|
||||
after = {
|
||||
'tzname': 'EST',
|
||||
'utcoffset': timedelta(hours = -5),
|
||||
'dst': timedelta(hours = 0),
|
||||
}
|
||||
|
||||
|
||||
class WarsawWMTEndTestCase(USEasternDSTStartTestCase):
|
||||
# In 1915, Warsaw changed from Warsaw Mean Time to Central European time.
|
||||
# This involved the clocks being set backwards, causing an end-of-DST-
|
||||
# like situation without DST being involved.
|
||||
tzinfo = pytz.timezone('Europe/Warsaw')
|
||||
transition_time = datetime(1915, 8, 4, 22, 36, 0, tzinfo=UTC)
|
||||
before = {
|
||||
'tzname': 'WMT',
|
||||
'utcoffset': timedelta(hours=1, minutes=24),
|
||||
'dst': timedelta(0),
|
||||
}
|
||||
after = {
|
||||
'tzname': 'CET',
|
||||
'utcoffset': timedelta(hours=1),
|
||||
'dst': timedelta(0),
|
||||
}
|
||||
|
||||
|
||||
class VilniusWMTEndTestCase(USEasternDSTStartTestCase):
|
||||
# At the end of 1916, Vilnius changed timezones putting its clock
|
||||
# forward by 11 minutes 35 seconds. Neither timezone was in DST mode.
|
||||
tzinfo = pytz.timezone('Europe/Vilnius')
|
||||
instant = timedelta(seconds=31)
|
||||
transition_time = datetime(1916, 12, 31, 22, 36, 00, tzinfo=UTC)
|
||||
before = {
|
||||
'tzname': 'WMT',
|
||||
'utcoffset': timedelta(hours=1, minutes=24),
|
||||
'dst': timedelta(0),
|
||||
}
|
||||
after = {
|
||||
'tzname': 'KMT',
|
||||
'utcoffset': timedelta(hours=1, minutes=36), # Really 1:35:36
|
||||
'dst': timedelta(0),
|
||||
}
|
||||
|
||||
|
||||
class VilniusCESTStartTestCase(USEasternDSTStartTestCase):
|
||||
# In 1941, Vilnius changed from MSK to CEST, switching to summer
|
||||
# time while simultaneously reducing its UTC offset by two hours,
|
||||
# causing the clocks to go backwards for this summer time
|
||||
# switchover.
|
||||
tzinfo = pytz.timezone('Europe/Vilnius')
|
||||
transition_time = datetime(1941, 6, 23, 21, 00, 00, tzinfo=UTC)
|
||||
before = {
|
||||
'tzname': 'MSK',
|
||||
'utcoffset': timedelta(hours=3),
|
||||
'dst': timedelta(0),
|
||||
}
|
||||
after = {
|
||||
'tzname': 'CEST',
|
||||
'utcoffset': timedelta(hours=2),
|
||||
'dst': timedelta(hours=1),
|
||||
}
|
||||
|
||||
|
||||
class LondonHistoryStartTestCase(USEasternDSTStartTestCase):
|
||||
# The first known timezone transition in London was in 1847 when
|
||||
# clocks were synchronized to GMT. However, we currently only
|
||||
# understand v1 format tzfile(5) files, which do not handle years
|
||||
# this far in the past, so our earliest known transition is in
|
||||
# 1916.
|
||||
tzinfo = pytz.timezone('Europe/London')
|
||||
# transition_time = datetime(1847, 12, 1, 1, 15, 00, tzinfo=UTC)
|
||||
# before = {
|
||||
# 'tzname': 'LMT',
|
||||
# 'utcoffset': timedelta(minutes=-75),
|
||||
# 'dst': timedelta(0),
|
||||
# }
|
||||
# after = {
|
||||
# 'tzname': 'GMT',
|
||||
# 'utcoffset': timedelta(0),
|
||||
# 'dst': timedelta(0),
|
||||
# }
|
||||
transition_time = datetime(1916, 5, 21, 2, 00, 00, tzinfo=UTC)
|
||||
before = {
|
||||
'tzname': 'GMT',
|
||||
'utcoffset': timedelta(0),
|
||||
'dst': timedelta(0),
|
||||
}
|
||||
after = {
|
||||
'tzname': 'BST',
|
||||
'utcoffset': timedelta(hours=1),
|
||||
'dst': timedelta(hours=1),
|
||||
}
|
||||
|
||||
|
||||
class LondonHistoryEndTestCase(USEasternDSTStartTestCase):
|
||||
# Timezone switchovers are projected into the future, even
|
||||
# though no official statements exist or could be believed even
|
||||
# if they did exist. We currently only check the last known
|
||||
# transition in 2037, as we are still using v1 format tzfile(5)
|
||||
# files.
|
||||
tzinfo = pytz.timezone('Europe/London')
|
||||
# transition_time = datetime(2499, 10, 25, 1, 0, 0, tzinfo=UTC)
|
||||
transition_time = datetime(2037, 10, 25, 1, 0, 0, tzinfo=UTC)
|
||||
before = {
|
||||
'tzname': 'BST',
|
||||
'utcoffset': timedelta(hours=1),
|
||||
'dst': timedelta(hours=1),
|
||||
}
|
||||
after = {
|
||||
'tzname': 'GMT',
|
||||
'utcoffset': timedelta(0),
|
||||
'dst': timedelta(0),
|
||||
}
|
||||
|
||||
|
||||
class NoumeaHistoryStartTestCase(USEasternDSTStartTestCase):
|
||||
# Noumea adopted a whole hour offset in 1912. Previously
|
||||
# it was 11 hours, 5 minutes and 48 seconds off UTC. However,
|
||||
# due to limitations of the Python datetime library, we need
|
||||
# to round that to 11 hours 6 minutes.
|
||||
tzinfo = pytz.timezone('Pacific/Noumea')
|
||||
transition_time = datetime(1912, 1, 12, 12, 54, 12, tzinfo=UTC)
|
||||
before = {
|
||||
'tzname': 'LMT',
|
||||
'utcoffset': timedelta(hours=11, minutes=6),
|
||||
'dst': timedelta(0),
|
||||
}
|
||||
after = {
|
||||
'tzname': '+11', # pre-2017a, NCT
|
||||
'utcoffset': timedelta(hours=11),
|
||||
'dst': timedelta(0),
|
||||
}
|
||||
|
||||
|
||||
class NoumeaDSTEndTestCase(USEasternDSTStartTestCase):
|
||||
# Noumea dropped DST in 1997.
|
||||
tzinfo = pytz.timezone('Pacific/Noumea')
|
||||
transition_time = datetime(1997, 3, 1, 15, 00, 00, tzinfo=UTC)
|
||||
before = {
|
||||
'tzname': '+12', # pre-2017a, NCST
|
||||
'utcoffset': timedelta(hours=12),
|
||||
'dst': timedelta(hours=1),
|
||||
}
|
||||
after = {
|
||||
'tzname': '+11', # pre-2017a, NCT
|
||||
'utcoffset': timedelta(hours=11),
|
||||
'dst': timedelta(0),
|
||||
}
|
||||
|
||||
|
||||
class NoumeaNoMoreDSTTestCase(NoumeaDSTEndTestCase):
|
||||
# Noumea dropped DST in 1997. Here we test that it stops occurring.
|
||||
transition_time = (
|
||||
NoumeaDSTEndTestCase.transition_time + timedelta(days=365*10))
|
||||
before = NoumeaDSTEndTestCase.after
|
||||
after = NoumeaDSTEndTestCase.after
|
||||
|
||||
|
||||
class TahitiTestCase(USEasternDSTStartTestCase):
|
||||
# Tahiti has had a single transition in its history.
|
||||
tzinfo = pytz.timezone('Pacific/Tahiti')
|
||||
transition_time = datetime(1912, 10, 1, 9, 58, 16, tzinfo=UTC)
|
||||
before = {
|
||||
'tzname': 'LMT',
|
||||
'utcoffset': timedelta(hours=-9, minutes=-58),
|
||||
'dst': timedelta(0),
|
||||
}
|
||||
after = {
|
||||
'tzname': '-10', # pre-2017a, TAHT
|
||||
'utcoffset': timedelta(hours=-10),
|
||||
'dst': timedelta(0),
|
||||
}
|
||||
|
||||
|
||||
class SamoaInternationalDateLineChange(USEasternDSTStartTestCase):
|
||||
# At the end of 2011, Samoa will switch from being east of the
|
||||
# international dateline to the west. There will be no Dec 30th
|
||||
# 2011 and it will switch from UTC-10 to UTC+14.
|
||||
tzinfo = pytz.timezone('Pacific/Apia')
|
||||
transition_time = datetime(2011, 12, 30, 10, 0, 0, tzinfo=UTC)
|
||||
before = {
|
||||
'tzname': '-10', # pre-2017a, SDT
|
||||
'utcoffset': timedelta(hours=-10),
|
||||
'dst': timedelta(hours=1),
|
||||
}
|
||||
after = {
|
||||
'tzname': '+14', # pre-2017a, WSDT
|
||||
'utcoffset': timedelta(hours=14),
|
||||
'dst': timedelta(hours=1),
|
||||
}
|
||||
|
||||
|
||||
class ReferenceUSEasternDSTStartTestCase(USEasternDSTStartTestCase):
|
||||
tzinfo = reference.Eastern
|
||||
def test_arithmetic(self):
|
||||
# Reference implementation cannot handle this
|
||||
pass
|
||||
|
||||
|
||||
class ReferenceUSEasternDSTEndTestCase(USEasternDSTEndTestCase):
|
||||
tzinfo = reference.Eastern
|
||||
|
||||
def testHourBefore(self):
|
||||
# Python's datetime library has a bug, where the hour before
|
||||
# a daylight saving transition is one hour out. For example,
|
||||
# at the end of US/Eastern daylight saving time, 01:00 EST
|
||||
# occurs twice (once at 05:00 UTC and once at 06:00 UTC),
|
||||
# whereas the first should actually be 01:00 EDT.
|
||||
# Note that this bug is by design - by accepting this ambiguity
|
||||
# for one hour once a year, an is_dst flag on datetime.time
|
||||
# became unnecessary.
|
||||
self._test_all(
|
||||
self.transition_time - timedelta(hours=1), self.after
|
||||
)
|
||||
|
||||
def testInstantBefore(self):
|
||||
self._test_all(
|
||||
self.transition_time - timedelta(seconds=1), self.after
|
||||
)
|
||||
|
||||
def test_arithmetic(self):
|
||||
# Reference implementation cannot handle this
|
||||
pass
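
To make the testHourBefore comment above concrete: at the 2002 end of US/Eastern DST, 05:00 UTC and 06:00 UTC both land on 01:00 local wall-clock time, and pytz distinguishes them through the tzinfo attached by astimezone(). A short illustration, assuming pytz is importable:

from datetime import datetime

import pytz

eastern = pytz.timezone('US/Eastern')

first = pytz.utc.localize(datetime(2002, 10, 27, 5, 0)).astimezone(eastern)
second = pytz.utc.localize(datetime(2002, 10, 27, 6, 0)).astimezone(eastern)

print(first.strftime('%H:%M %Z%z'))   # 01:00 EDT-0400 (still daylight time)
print(second.strftime('%H:%M %Z%z'))  # 01:00 EST-0500 (clocks have fallen back)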
|
||||
|
||||
|
||||
class LocalTestCase(unittest.TestCase):
|
||||
def testLocalize(self):
|
||||
loc_tz = pytz.timezone('Europe/Amsterdam')
|
||||
|
||||
loc_time = loc_tz.localize(datetime(1930, 5, 10, 0, 0, 0))
|
||||
# Actually +00:19:32, but Python datetime rounds this
|
||||
self.assertEqual(loc_time.strftime('%Z%z'), 'AMT+0020')
|
||||
|
||||
loc_time = loc_tz.localize(datetime(1930, 5, 20, 0, 0, 0))
|
||||
# Actually +01:19:32, but Python datetime rounds this
|
||||
self.assertEqual(loc_time.strftime('%Z%z'), 'NST+0120')
|
||||
|
||||
loc_time = loc_tz.localize(datetime(1940, 5, 10, 0, 0, 0))
|
||||
# pre-2017a, abbreviation was NCT
|
||||
self.assertEqual(loc_time.strftime('%Z%z'), '+0020+0020')
|
||||
|
||||
loc_time = loc_tz.localize(datetime(1940, 5, 20, 0, 0, 0))
|
||||
self.assertEqual(loc_time.strftime('%Z%z'), 'CEST+0200')
|
||||
|
||||
loc_time = loc_tz.localize(datetime(2004, 2, 1, 0, 0, 0))
|
||||
self.assertEqual(loc_time.strftime('%Z%z'), 'CET+0100')
|
||||
|
||||
loc_time = loc_tz.localize(datetime(2004, 4, 1, 0, 0, 0))
|
||||
self.assertEqual(loc_time.strftime('%Z%z'), 'CEST+0200')
|
||||
|
||||
tz = pytz.timezone('Europe/Amsterdam')
|
||||
loc_time = loc_tz.localize(datetime(1943, 3, 29, 1, 59, 59))
|
||||
self.assertEqual(loc_time.strftime('%Z%z'), 'CET+0100')
|
||||
|
||||
|
||||
# Switch to US
|
||||
loc_tz = pytz.timezone('US/Eastern')
|
||||
|
||||
# End of DST ambiguity check
|
||||
loc_time = loc_tz.localize(datetime(1918, 10, 27, 1, 59, 59), is_dst=1)
|
||||
self.assertEqual(loc_time.strftime('%Z%z'), 'EDT-0400')
|
||||
|
||||
loc_time = loc_tz.localize(datetime(1918, 10, 27, 1, 59, 59), is_dst=0)
|
||||
self.assertEqual(loc_time.strftime('%Z%z'), 'EST-0500')
|
||||
|
||||
self.assertRaises(pytz.AmbiguousTimeError,
|
||||
loc_tz.localize, datetime(1918, 10, 27, 1, 59, 59), is_dst=None
|
||||
)
|
||||
|
||||
# Start of DST non-existent times
|
||||
loc_time = loc_tz.localize(datetime(1918, 3, 31, 2, 0, 0), is_dst=0)
|
||||
self.assertEqual(loc_time.strftime('%Z%z'), 'EST-0500')
|
||||
|
||||
loc_time = loc_tz.localize(datetime(1918, 3, 31, 2, 0, 0), is_dst=1)
|
||||
self.assertEqual(loc_time.strftime('%Z%z'), 'EDT-0400')
|
||||
|
||||
self.assertRaises(pytz.NonExistentTimeError,
|
||||
loc_tz.localize, datetime(1918, 3, 31, 2, 0, 0), is_dst=None
|
||||
)
|
||||
|
||||
# Weird changes - war time and peace time both is_dst==True
|
||||
|
||||
loc_time = loc_tz.localize(datetime(1942, 2, 9, 3, 0, 0))
|
||||
self.assertEqual(loc_time.strftime('%Z%z'), 'EWT-0400')
|
||||
|
||||
loc_time = loc_tz.localize(datetime(1945, 8, 14, 19, 0, 0))
|
||||
self.assertEqual(loc_time.strftime('%Z%z'), 'EPT-0400')
|
||||
|
||||
loc_time = loc_tz.localize(datetime(1945, 9, 30, 1, 0, 0), is_dst=1)
|
||||
self.assertEqual(loc_time.strftime('%Z%z'), 'EPT-0400')
|
||||
|
||||
loc_time = loc_tz.localize(datetime(1945, 9, 30, 1, 0, 0), is_dst=0)
|
||||
self.assertEqual(loc_time.strftime('%Z%z'), 'EST-0500')
|
||||
|
||||
# Weird changes - ambiguous time (end-of-DST like) but is_dst==False
|
||||
for zonename, ambiguous_naive, expected in [
|
||||
('Europe/Warsaw', datetime(1915, 8, 4, 23, 59, 59),
|
||||
['1915-08-04 23:59:59 WMT+0124',
|
||||
'1915-08-04 23:59:59 CET+0100']),
|
||||
('Europe/Moscow', datetime(2014, 10, 26, 1, 30),
|
||||
['2014-10-26 01:30:00 MSK+0400',
|
||||
'2014-10-26 01:30:00 MSK+0300'])]:
|
||||
loc_tz = pytz.timezone(zonename)
|
||||
self.assertRaises(pytz.AmbiguousTimeError,
|
||||
loc_tz.localize, ambiguous_naive, is_dst=None
|
||||
)
|
||||
# Also test non-boolean is_dst in the weird case
|
||||
for dst in [True, timedelta(1), False, timedelta(0)]:
|
||||
loc_time = loc_tz.localize(ambiguous_naive, is_dst=dst)
|
||||
self.assertEqual(loc_time.strftime(fmt), expected[not dst])
|
||||
|
||||
def testNormalize(self):
|
||||
tz = pytz.timezone('US/Eastern')
|
||||
dt = datetime(2004, 4, 4, 7, 0, 0, tzinfo=UTC).astimezone(tz)
|
||||
dt2 = dt - timedelta(minutes=10)
|
||||
self.assertEqual(
|
||||
dt2.strftime('%Y-%m-%d %H:%M:%S %Z%z'),
|
||||
'2004-04-04 02:50:00 EDT-0400'
|
||||
)
|
||||
|
||||
dt2 = tz.normalize(dt2)
|
||||
self.assertEqual(
|
||||
dt2.strftime('%Y-%m-%d %H:%M:%S %Z%z'),
|
||||
'2004-04-04 01:50:00 EST-0500'
|
||||
)
|
||||
|
||||
def testPartialMinuteOffsets(self):
|
||||
# utcoffset in Amsterdam was not a whole minute until 1937
|
||||
# However, we fudge this by rounding them, as the Python
|
||||
# datetime library only supports whole-minute offsets.
|
||||
tz = pytz.timezone('Europe/Amsterdam')
|
||||
utc_dt = datetime(1914, 1, 1, 13, 40, 28, tzinfo=UTC) # correct
|
||||
utc_dt = utc_dt.replace(second=0) # But we need to fudge it
|
||||
loc_dt = utc_dt.astimezone(tz)
|
||||
self.assertEqual(
|
||||
loc_dt.strftime('%Y-%m-%d %H:%M:%S %Z%z'),
|
||||
'1914-01-01 14:00:00 AMT+0020'
|
||||
)
|
||||
|
||||
# And get back...
|
||||
utc_dt = loc_dt.astimezone(UTC)
|
||||
self.assertEqual(
|
||||
utc_dt.strftime('%Y-%m-%d %H:%M:%S %Z%z'),
|
||||
'1914-01-01 13:40:00 UTC+0000'
|
||||
)
|
||||
|
||||
def no_testCreateLocaltime(self):
|
||||
# It would be nice if this worked, but it doesn't.
|
||||
tz = pytz.timezone('Europe/Amsterdam')
|
||||
dt = datetime(2004, 10, 31, 2, 0, 0, tzinfo=tz)
|
||||
self.assertEqual(
|
||||
dt.strftime(fmt),
|
||||
'2004-10-31 02:00:00 CET+0100'
|
||||
)
|
||||
|
||||
|
||||
class CommonTimezonesTestCase(unittest.TestCase):
|
||||
def test_bratislava(self):
|
||||
# Bratislava is the default timezone for Slovakia, but our
|
||||
# heuristics were not adding it to common_timezones. Ideally,
|
||||
# common_timezones should be populated from zone.tab at runtime,
|
||||
# but I'm hesitant to pay the startup cost as loading the list
|
||||
# on demand whilst remaining backwards compatible seems
|
||||
# difficult.
|
||||
self.assertTrue('Europe/Bratislava' in pytz.common_timezones)
|
||||
self.assertTrue('Europe/Bratislava' in pytz.common_timezones_set)
|
||||
|
||||
def test_us_eastern(self):
|
||||
self.assertTrue('US/Eastern' in pytz.common_timezones)
|
||||
self.assertTrue('US/Eastern' in pytz.common_timezones_set)
|
||||
|
||||
def test_belfast(self):
|
||||
# Belfast uses London time.
|
||||
self.assertTrue('Europe/Belfast' in pytz.all_timezones_set)
|
||||
self.assertFalse('Europe/Belfast' in pytz.common_timezones)
|
||||
self.assertFalse('Europe/Belfast' in pytz.common_timezones_set)
|
||||
|
||||
|
||||
class BaseTzInfoTestCase:
|
||||
'''Ensure UTC, StaticTzInfo and DstTzInfo work consistently.
|
||||
|
||||
These tests are run for each type of tzinfo.
|
||||
'''
|
||||
tz = None # override
|
||||
tz_class = None # override
|
||||
|
||||
def test_expectedclass(self):
|
||||
self.assertTrue(isinstance(self.tz, self.tz_class))
|
||||
|
||||
def test_fromutc(self):
|
||||
# naive datetime.
|
||||
dt1 = datetime(2011, 10, 31)
|
||||
|
||||
# localized datetime, same timezone.
|
||||
dt2 = self.tz.localize(dt1)
|
||||
|
||||
# Both should give the same results. Note that the standard
|
||||
# Python tzinfo.fromutc() only supports the second.
|
||||
for dt in [dt1, dt2]:
|
||||
loc_dt = self.tz.fromutc(dt)
|
||||
loc_dt2 = pytz.utc.localize(dt1).astimezone(self.tz)
|
||||
self.assertEqual(loc_dt, loc_dt2)
|
||||
|
||||
# localized datetime, different timezone.
|
||||
new_tz = pytz.timezone('Europe/Paris')
|
||||
self.assertTrue(self.tz is not new_tz)
|
||||
dt3 = new_tz.localize(dt1)
|
||||
self.assertRaises(ValueError, self.tz.fromutc, dt3)
|
||||
|
||||
def test_normalize(self):
|
||||
other_tz = pytz.timezone('Europe/Paris')
|
||||
self.assertTrue(self.tz is not other_tz)
|
||||
|
||||
dt = datetime(2012, 3, 26, 12, 0)
|
||||
other_dt = other_tz.localize(dt)
|
||||
|
||||
local_dt = self.tz.normalize(other_dt)
|
||||
|
||||
self.assertTrue(local_dt.tzinfo is not other_dt.tzinfo)
|
||||
self.assertNotEqual(
|
||||
local_dt.replace(tzinfo=None), other_dt.replace(tzinfo=None))
|
||||
|
||||
def test_astimezone(self):
|
||||
other_tz = pytz.timezone('Europe/Paris')
|
||||
self.assertTrue(self.tz is not other_tz)
|
||||
|
||||
dt = datetime(2012, 3, 26, 12, 0)
|
||||
other_dt = other_tz.localize(dt)
|
||||
|
||||
local_dt = other_dt.astimezone(self.tz)
|
||||
|
||||
self.assertTrue(local_dt.tzinfo is not other_dt.tzinfo)
|
||||
self.assertNotEqual(
|
||||
local_dt.replace(tzinfo=None), other_dt.replace(tzinfo=None))
|
||||
|
||||
|
||||
class OptimizedUTCTestCase(unittest.TestCase, BaseTzInfoTestCase):
|
||||
tz = pytz.utc
|
||||
tz_class = tz.__class__
|
||||
|
||||
|
||||
class LegacyUTCTestCase(unittest.TestCase, BaseTzInfoTestCase):
|
||||
# Deprecated timezone, but useful for comparison tests.
|
||||
tz = pytz.timezone('Etc/UTC')
|
||||
tz_class = StaticTzInfo
|
||||
|
||||
|
||||
class StaticTzInfoTestCase(unittest.TestCase, BaseTzInfoTestCase):
|
||||
tz = pytz.timezone('GMT')
|
||||
tz_class = StaticTzInfo
|
||||
|
||||
|
||||
class DstTzInfoTestCase(unittest.TestCase, BaseTzInfoTestCase):
|
||||
tz = pytz.timezone('Australia/Melbourne')
|
||||
tz_class = DstTzInfo
|
||||
|
||||
|
||||
def test_suite():
|
||||
suite = unittest.TestSuite()
|
||||
suite.addTest(doctest.DocTestSuite('pytz'))
|
||||
suite.addTest(doctest.DocTestSuite('pytz.tzinfo'))
|
||||
import test_tzinfo
|
||||
suite.addTest(unittest.defaultTestLoader.loadTestsFromModule(test_tzinfo))
|
||||
return suite
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
warnings.simplefilter("error") # Warnings should be fatal in tests.
|
||||
unittest.main(defaultTest='test_suite')
|
|
@ -3,38 +3,37 @@
|
|||
$Id: tzfile.py,v 1.8 2004/06/03 00:15:24 zenzen Exp $
|
||||
'''
|
||||
|
||||
try:
|
||||
from cStringIO import StringIO
|
||||
except ImportError:
|
||||
from io import StringIO
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime
|
||||
from struct import unpack, calcsize
|
||||
|
||||
from pytz.tzinfo import StaticTzInfo, DstTzInfo, memorized_ttinfo
|
||||
from pytz.tzinfo import memorized_datetime, memorized_timedelta
|
||||
|
||||
|
||||
def _byte_string(s):
|
||||
"""Cast a string or byte string to an ASCII byte string."""
|
||||
return s.encode('ASCII')
|
||||
|
||||
_NULL = _byte_string('\0')
|
||||
|
||||
|
||||
def _std_string(s):
|
||||
"""Cast a string or byte string to an ASCII string."""
|
||||
return str(s.decode('ASCII'))
|
||||
|
||||
|
||||
def build_tzinfo(zone, fp):
|
||||
head_fmt = '>4s c 15x 6l'
|
||||
head_size = calcsize(head_fmt)
|
||||
(magic, format, ttisgmtcnt, ttisstdcnt,leapcnt, timecnt,
|
||||
typecnt, charcnt) = unpack(head_fmt, fp.read(head_size))
|
||||
(magic, format, ttisgmtcnt, ttisstdcnt, leapcnt, timecnt,
|
||||
typecnt, charcnt) = unpack(head_fmt, fp.read(head_size))
|
||||
|
||||
# Make sure it is a tzfile(5) file
|
||||
assert magic == _byte_string('TZif'), 'Got magic %s' % repr(magic)
|
||||
|
||||
# Read out the transition times, localtime indices and ttinfo structures.
|
||||
data_fmt = '>%(timecnt)dl %(timecnt)dB %(ttinfo)s %(charcnt)ds' % dict(
|
||||
timecnt=timecnt, ttinfo='lBB'*typecnt, charcnt=charcnt)
|
||||
timecnt=timecnt, ttinfo='lBB' * typecnt, charcnt=charcnt)
|
||||
data_size = calcsize(data_fmt)
|
||||
data = unpack(data_fmt, fp.read(data_size))
|
||||
|
||||
|
@ -53,7 +52,7 @@ def build_tzinfo(zone, fp):
|
|||
i = 0
|
||||
while i < len(ttinfo_raw):
|
||||
# have we looked up this timezone name yet?
|
||||
tzname_offset = ttinfo_raw[i+2]
|
||||
tzname_offset = ttinfo_raw[i + 2]
|
||||
if tzname_offset not in tznames:
|
||||
nul = tznames_raw.find(_NULL, tzname_offset)
|
||||
if nul < 0:
|
||||
|
@ -61,12 +60,12 @@ def build_tzinfo(zone, fp):
|
|||
tznames[tzname_offset] = _std_string(
|
||||
tznames_raw[tzname_offset:nul])
|
||||
ttinfo.append((ttinfo_raw[i],
|
||||
bool(ttinfo_raw[i+1]),
|
||||
bool(ttinfo_raw[i + 1]),
|
||||
tznames[tzname_offset]))
|
||||
i += 3
|
||||
|
||||
# Now build the timezone object
|
||||
if len(ttinfo) ==1 or len(transitions) == 0:
|
||||
if len(ttinfo) == 1 or len(transitions) == 0:
|
||||
ttinfo[0][0], ttinfo[0][2]
|
||||
cls = type(zone, (StaticTzInfo,), dict(
|
||||
zone=zone,
|
||||
|
@ -91,21 +90,21 @@ def build_tzinfo(zone, fp):
|
|||
if not inf[1]:
|
||||
dst = 0
|
||||
else:
|
||||
for j in range(i-1, -1, -1):
|
||||
for j in range(i - 1, -1, -1):
|
||||
prev_inf = ttinfo[lindexes[j]]
|
||||
if not prev_inf[1]:
|
||||
break
|
||||
dst = inf[0] - prev_inf[0] # dst offset
|
||||
dst = inf[0] - prev_inf[0] # dst offset
|
||||
|
||||
# Bad dst? Look further. DST > 24 hours happens when
|
||||
# a timezone has moved across the international dateline.
|
||||
if dst <= 0 or dst > 3600*3:
|
||||
for j in range(i+1, len(transitions)):
|
||||
if dst <= 0 or dst > 3600 * 3:
|
||||
for j in range(i + 1, len(transitions)):
|
||||
stdinf = ttinfo[lindexes[j]]
|
||||
if not stdinf[1]:
|
||||
dst = inf[0] - stdinf[0]
|
||||
if dst > 0:
|
||||
break # Found a useful std time.
|
||||
break # Found a useful std time.
|
||||
|
||||
tzname = inf[2]
|
||||
|
||||
|
@ -129,9 +128,7 @@ if __name__ == '__main__':
|
|||
from pprint import pprint
|
||||
base = os.path.join(os.path.dirname(__file__), 'zoneinfo')
|
||||
tz = build_tzinfo('Australia/Melbourne',
|
||||
open(os.path.join(base,'Australia','Melbourne'), 'rb'))
|
||||
open(os.path.join(base, 'Australia', 'Melbourne'), 'rb'))
|
||||
tz = build_tzinfo('US/Eastern',
|
||||
open(os.path.join(base,'US','Eastern'), 'rb'))
|
||||
open(os.path.join(base, 'US', 'Eastern'), 'rb'))
|
||||
pprint(tz._utc_transition_times)
|
||||
#print tz.asPython(4)
|
||||
#print tz.transitions_mapping
|
||||
|
|
|
@ -13,6 +13,8 @@ from pytz.exceptions import AmbiguousTimeError, NonExistentTimeError
|
|||
__all__ = []
|
||||
|
||||
_timedelta_cache = {}
|
||||
|
||||
|
||||
def memorized_timedelta(seconds):
|
||||
'''Create only one instance of each distinct timedelta'''
|
||||
try:
|
||||
|
@ -24,6 +26,8 @@ def memorized_timedelta(seconds):
|
|||
|
||||
_epoch = datetime.utcfromtimestamp(0)
|
||||
_datetime_cache = {0: _epoch}
|
||||
|
||||
|
||||
def memorized_datetime(seconds):
|
||||
'''Create only one instance of each distinct datetime'''
|
||||
try:
|
||||
|
@ -36,21 +40,24 @@ def memorized_datetime(seconds):
|
|||
return dt
|
||||
|
||||
_ttinfo_cache = {}
|
||||
|
||||
|
||||
def memorized_ttinfo(*args):
|
||||
'''Create only one instance of each distinct tuple'''
|
||||
try:
|
||||
return _ttinfo_cache[args]
|
||||
except KeyError:
|
||||
ttinfo = (
|
||||
memorized_timedelta(args[0]),
|
||||
memorized_timedelta(args[1]),
|
||||
args[2]
|
||||
)
|
||||
memorized_timedelta(args[0]),
|
||||
memorized_timedelta(args[1]),
|
||||
args[2]
|
||||
)
|
||||
_ttinfo_cache[args] = ttinfo
|
||||
return ttinfo
|
||||
|
||||
_notime = memorized_timedelta(0)
|
||||
|
||||
|
||||
def _to_seconds(td):
|
||||
'''Convert a timedelta to seconds'''
|
||||
return td.seconds + td.days * 24 * 60 * 60
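
The memorized_* helpers in this hunk all use the same cache-on-first-use idiom so that equal offsets share one timedelta/datetime/tuple instance, keeping comparisons and pickles cheap. The hunk truncates memorized_timedelta's body; the sketch below shows what that pattern looks like in full, as a reconstruction under that assumption rather than a verbatim copy:

from datetime import timedelta

_timedelta_cache = {}

def memorized_timedelta(seconds):
    '''Create only one instance of each distinct timedelta (sketch).'''
    try:
        return _timedelta_cache[seconds]
    except KeyError:
        delta = timedelta(seconds=seconds)
        _timedelta_cache[seconds] = delta
        return delta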
|
||||
|
@ -154,14 +161,20 @@ class DstTzInfo(BaseTzInfo):
|
|||
timezone definition.
|
||||
'''
|
||||
# Overridden in subclass
|
||||
_utc_transition_times = None # Sorted list of DST transition times in UTC
|
||||
_transition_info = None # [(utcoffset, dstoffset, tzname)] corresponding
|
||||
# to _utc_transition_times entries
|
||||
|
||||
# Sorted list of DST transition times, UTC
|
||||
_utc_transition_times = None
|
||||
|
||||
# [(utcoffset, dstoffset, tzname)] corresponding to
|
||||
# _utc_transition_times entries
|
||||
_transition_info = None
|
||||
|
||||
zone = None
|
||||
|
||||
# Set in __init__
|
||||
|
||||
_tzinfos = None
|
||||
_dst = None # DST offset
|
||||
_dst = None # DST offset
|
||||
|
||||
def __init__(self, _inf=None, _tzinfos=None):
|
||||
if _inf:
|
||||
|
@ -170,7 +183,8 @@ class DstTzInfo(BaseTzInfo):
|
|||
else:
|
||||
_tzinfos = {}
|
||||
self._tzinfos = _tzinfos
|
||||
self._utcoffset, self._dst, self._tzname = self._transition_info[0]
|
||||
self._utcoffset, self._dst, self._tzname = (
|
||||
self._transition_info[0])
|
||||
_tzinfos[self._transition_info[0]] = self
|
||||
for inf in self._transition_info[1:]:
|
||||
if inf not in _tzinfos:
|
||||
|
@ -178,8 +192,8 @@ class DstTzInfo(BaseTzInfo):
|
|||
|
||||
def fromutc(self, dt):
|
||||
'''See datetime.tzinfo.fromutc'''
|
||||
if (dt.tzinfo is not None
|
||||
and getattr(dt.tzinfo, '_tzinfos', None) is not self._tzinfos):
|
||||
if (dt.tzinfo is not None and
|
||||
getattr(dt.tzinfo, '_tzinfos', None) is not self._tzinfos):
|
||||
raise ValueError('fromutc: dt.tzinfo is not self')
|
||||
dt = dt.replace(tzinfo=None)
|
||||
idx = max(0, bisect_right(self._utc_transition_times, dt) - 1)
|
||||
|
@ -337,8 +351,8 @@ class DstTzInfo(BaseTzInfo):
|
|||
# obtain the correct timezone by winding the clock back.
|
||||
else:
|
||||
return self.localize(
|
||||
dt - timedelta(hours=6), is_dst=False) + timedelta(hours=6)
|
||||
|
||||
dt - timedelta(hours=6),
|
||||
is_dst=False) + timedelta(hours=6)
|
||||
|
||||
# If we get this far, we have multiple possible timezones - this
|
||||
# is an ambiguous case occurring during the end-of-DST transition.
|
||||
|
@ -351,9 +365,8 @@ class DstTzInfo(BaseTzInfo):
|
|||
# Filter out the possibilities that don't match the requested
|
||||
# is_dst
|
||||
filtered_possible_loc_dt = [
|
||||
p for p in possible_loc_dt
|
||||
if bool(p.tzinfo._dst) == is_dst
|
||||
]
|
||||
p for p in possible_loc_dt if bool(p.tzinfo._dst) == is_dst
|
||||
]
|
||||
|
||||
# Hopefully we only have one possibility left. Return it.
|
||||
if len(filtered_possible_loc_dt) == 1:
|
||||
|
@ -372,9 +385,10 @@ class DstTzInfo(BaseTzInfo):
|
|||
# Choose the earliest (by UTC) applicable timezone if is_dst=True
|
||||
# Choose the latest (by UTC) applicable timezone if is_dst=False
|
||||
# i.e., behave like end-of-DST transition
|
||||
dates = {} # utc -> local
|
||||
dates = {} # utc -> local
|
||||
for local_dt in filtered_possible_loc_dt:
|
||||
utc_time = local_dt.replace(tzinfo=None) - local_dt.tzinfo._utcoffset
|
||||
utc_time = (
|
||||
local_dt.replace(tzinfo=None) - local_dt.tzinfo._utcoffset)
|
||||
assert utc_time not in dates
|
||||
dates[utc_time] = local_dt
|
||||
return dates[[min, max][not is_dst](dates)]
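
This earliest/latest selection is what gives is_dst its meaning for ambiguous wall-clock times. A brief usage sketch against the end-of-DST case, assuming pytz is importable:

from datetime import datetime

import pytz

eastern = pytz.timezone('US/Eastern')
ambiguous = datetime(2002, 10, 27, 1, 30)  # this wall-clock time occurs twice

print(eastern.localize(ambiguous, is_dst=True).strftime('%Z%z'))   # EDT-0400
print(eastern.localize(ambiguous, is_dst=False).strftime('%Z%z'))  # EST-0500

try:
    eastern.localize(ambiguous, is_dst=None)  # refuse to guess
except pytz.AmbiguousTimeError:
    pass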
|
||||
|
@ -389,11 +403,11 @@ class DstTzInfo(BaseTzInfo):
|
|||
>>> tz = timezone('America/St_Johns')
|
||||
>>> ambiguous = datetime(2009, 10, 31, 23, 30)
|
||||
|
||||
>>> tz.utcoffset(ambiguous, is_dst=False)
|
||||
datetime.timedelta(-1, 73800)
|
||||
>>> str(tz.utcoffset(ambiguous, is_dst=False))
|
||||
'-1 day, 20:30:00'
|
||||
|
||||
>>> tz.utcoffset(ambiguous, is_dst=True)
|
||||
datetime.timedelta(-1, 77400)
|
||||
>>> str(tz.utcoffset(ambiguous, is_dst=True))
|
||||
'-1 day, 21:30:00'
|
||||
|
||||
>>> try:
|
||||
... tz.utcoffset(ambiguous)
|
||||
|
@ -421,19 +435,19 @@ class DstTzInfo(BaseTzInfo):
|
|||
|
||||
>>> normal = datetime(2009, 9, 1)
|
||||
|
||||
>>> tz.dst(normal)
|
||||
datetime.timedelta(0, 3600)
|
||||
>>> tz.dst(normal, is_dst=False)
|
||||
datetime.timedelta(0, 3600)
|
||||
>>> tz.dst(normal, is_dst=True)
|
||||
datetime.timedelta(0, 3600)
|
||||
>>> str(tz.dst(normal))
|
||||
'1:00:00'
|
||||
>>> str(tz.dst(normal, is_dst=False))
|
||||
'1:00:00'
|
||||
>>> str(tz.dst(normal, is_dst=True))
|
||||
'1:00:00'
|
||||
|
||||
>>> ambiguous = datetime(2009, 10, 31, 23, 30)
|
||||
|
||||
>>> tz.dst(ambiguous, is_dst=False)
|
||||
datetime.timedelta(0)
|
||||
>>> tz.dst(ambiguous, is_dst=True)
|
||||
datetime.timedelta(0, 3600)
|
||||
>>> str(tz.dst(ambiguous, is_dst=False))
|
||||
'0:00:00'
|
||||
>>> str(tz.dst(ambiguous, is_dst=True))
|
||||
'1:00:00'
|
||||
>>> try:
|
||||
... tz.dst(ambiguous)
|
||||
... except AmbiguousTimeError:
|
||||
|
@ -494,23 +508,22 @@ class DstTzInfo(BaseTzInfo):
|
|||
dst = 'STD'
|
||||
if self._utcoffset > _notime:
|
||||
return '<DstTzInfo %r %s+%s %s>' % (
|
||||
self.zone, self._tzname, self._utcoffset, dst
|
||||
)
|
||||
self.zone, self._tzname, self._utcoffset, dst
|
||||
)
|
||||
else:
|
||||
return '<DstTzInfo %r %s%s %s>' % (
|
||||
self.zone, self._tzname, self._utcoffset, dst
|
||||
)
|
||||
self.zone, self._tzname, self._utcoffset, dst
|
||||
)
|
||||
|
||||
def __reduce__(self):
|
||||
# Special pickle so the zone remains a singleton and to cope with
|
||||
# database changes.
|
||||
return pytz._p, (
|
||||
self.zone,
|
||||
_to_seconds(self._utcoffset),
|
||||
_to_seconds(self._dst),
|
||||
self._tzname
|
||||
)
|
||||
|
||||
self.zone,
|
||||
_to_seconds(self._utcoffset),
|
||||
_to_seconds(self._dst),
|
||||
self._tzname
|
||||
)
|
||||
|
||||
|
||||
def unpickler(zone, utcoffset=None, dstoffset=None, tzname=None):
|
||||
|
@ -549,8 +562,8 @@ def unpickler(zone, utcoffset=None, dstoffset=None, tzname=None):
|
|||
# get changed from the initial guess by the database maintainers to
|
||||
# match reality when this information is discovered.
|
||||
for localized_tz in tz._tzinfos.values():
|
||||
if (localized_tz._utcoffset == utcoffset
|
||||
and localized_tz._dst == dstoffset):
|
||||
if (localized_tz._utcoffset == utcoffset and
|
||||
localized_tz._dst == dstoffset):
|
||||
return localized_tz
|
||||
|
||||
# This (utcoffset, dstoffset) information has been removed from the
|
||||
|
|
Binary file not shown.
Some files were not shown because too many files have changed in this diff.