# Source code for watson.db.commands

# -*- coding: utf-8 -*-
import codecs
import collections
import enum
import json
import os
import sys
from datetime import datetime

from alembic import command as alembic_command
from alembic.config import Config as AlembicConfig
from sqlacodegen import codegen
from sqlalchemy import create_engine, inspect, schema
from sqlalchemy.ext.declarative.clsregistry import _ModuleMarker

from watson.common import imports
from watson.console import ConsoleError, command
from watson.console.decorators import arg
from watson.db import engine, fixtures, session
from watson.di import ContainerAware


class Config(AlembicConfig):
    """Alembic configuration that points at this package's own templates."""

    def get_template_directory(self):
        """Return the path of the ``alembic/templates`` directory.

        The path is built from the absolute directory that contains this
        module, so it does not depend on the current working directory.
        """
        module_dir = os.path.dirname(__file__)
        return os.path.join(
            os.path.abspath(module_dir), 'alembic', 'templates')


class BaseDatabaseCommand(ContainerAware):
    """Shared base for the database console commands.

    Holds the ``db`` section of the application configuration.
    """

    # Dependency definition (presumably consumed by the watson.di container):
    # the ``config`` init argument is looked up as the ``db`` key of the
    # ``application.config`` entry.
    __ioc_definition__ = {
        'init': {
            'config': (
                lambda container: container.get('application.config')['db']
            ),
        },
    }

    def __init__(self, config):
        """Keep a reference to the database configuration mapping."""
        self.config = config


class Database(command.Base, BaseDatabaseCommand):
    """Database commands.
    """
    name = 'db'

    @property
    def metadata(self):
        """Map every configured connection name to its loaded metadata."""
        return {
            name: self._load_metadata(options['metadata'])
            for name, options in self.config['connections'].items()
        }

    @property
    def connections(self):
        """Map every configured connection name to its connection string."""
        return {
            name: options['connection_string']
            for name, options in self.config['connections'].items()
        }

    @property
    def sessions(self):
        """Sessions from the container, keyed by connection name."""
        return self._session_or_engine('session')

    @property
    def engines(self):
        """Engines from the container, keyed by connection name."""
        return self._session_or_engine('engine')

    def _load_metadata(self, metadata):
        """Resolve the ``metadata`` option of a connection.

        Args:
            metadata: Either a dotted-path string, which is imported, or an
                already-loaded object, which is returned unchanged.

        Raises:
            ConsoleError: If the string definition cannot be loaded.
        """
        if not isinstance(metadata, str):
            # Already an object: hand it back instead of silently
            # returning None.
            return metadata
        try:
            return imports.load_definition_from_string(metadata)
        except Exception as e:
            raise ConsoleError(
                'Missing connection metadata for {} ({})'.format(
                    metadata, e)) from e

    def _session_or_engine(self, type_):
        """Retrieves all the sessions or engines from the container.

        Args:
            type_ (string): Name of a module imported by this file
                ('session' or 'engine'); its ``NAME`` template is formatted
                with each connection name to build the container key.
        """
        name_template = globals()[type_].NAME
        return {
            name: self.container.get(name_template.format(name))
            for name in self.config['connections']
        }

    @arg('drop', action='store_true', default=False, optional=True)
    def create(self, drop):
        """Create the relevant databases.

        Args:
            drop (bool): Forwarded to engine.create_db as its drop flag
        """
        engines = self.engines
        for database, model_base in self.metadata.items():
            self.write('Creating database {}...'.format(database))
            engine.create_db(engines[database], model_base, drop=drop)
        self.write('Created the databases.')
        return True

    @arg()
    def populate(self):
        """Add data from fixtures to the database(s).
        """
        if 'fixtures' not in self.config:
            self.write('No fixtures to add.')
            return False
        self.write('Adding fixtures...')
        sessions = self.sessions
        total = fixtures.populate_all(sessions, self.config['fixtures'])
        self.write('Added {} fixtures to {} database(s).'.format(
            total, len(sessions)))
        return True

    def _get_models_in_session(self, name, model):
        """Serialize every row of ``model`` found in the named session.

        Args:
            name: Connection name used to pick the session.
            model: A class-registry entry; module markers are skipped.

        Returns:
            None for a module marker, otherwise a list of ordered dicts with
            the qualified class name under 'class' and the column values,
            sorted by column name, under 'fields'.
        """
        if isinstance(model, _ModuleMarker):
            return None
        # Column names and the class name are the same for every row, so
        # compute them once instead of per object.
        mapper = inspect(model).mapper
        attr_names = [c_attr.key for c_attr in mapper.column_attrs]
        class_name = imports.get_qualified_name(model)
        results = []
        for obj in self.sessions[name].query(model):
            fields = {}
            for column in attr_names:
                value = getattr(obj, column)
                # Convert values that json.dumps cannot encode directly.
                if isinstance(value, enum.Enum):
                    value = value.value
                elif isinstance(value, datetime):
                    value = str(value)
                elif isinstance(value, bytes):
                    value = value.decode('utf-8')
                fields[column] = value
            new_obj = collections.OrderedDict()
            new_obj['class'] = class_name
            new_obj['fields'] = collections.OrderedDict(
                sorted(fields.items(), key=lambda k: k[0]))
            results.append(new_obj)
        return results

    @arg('models', optional=True)
    @arg('output_to_stdout', default=False, optional=True)
    def generate_fixtures(self, models, output_to_stdout):
        """Generate fixture data in json format.

        Args:
            models (string): A comma separated list of models to output
            output_to_stdout (boolean): Whether or not to output to the stdout
        """
        if models:
            models = models.split(',')
        for name, options in self.config['connections'].items():
            metadata = self._load_metadata(options['metadata'])
            for model in metadata._decl_class_registry.values():
                model_name = imports.get_qualified_name(model)
                if models and model_name not in models:
                    continue
                records = self._get_models_in_session(name, model)
                if not records:
                    # Module marker (None) or a model without any rows.
                    continue
                records = json.dumps(records, indent=4)
                if output_to_stdout:
                    self.write(records)
                else:
                    model_name, path = fixtures.save(
                        model, records, self.config['fixtures'])
                    self.write(
                        'Created fixture for {} at {}'.format(
                            model_name, path))

    @arg()
    def dump(self):
        """Print the Schema of the database.
        """
        def dump_sql(sql, *multiparams, **params):
            self.write(str(sql))

        connections = self.connections
        for database, model_base in self.metadata.items():
            self.write('Schema for "{}" from metadata {}...'.format(
                database, repr(model_base)))
            self.write()
            # The mock strategy hands each statement to dump_sql rather than
            # executing it against the database.
            _engine = create_engine(
                connections[database], strategy='mock', executor=dump_sql)
            model_base.metadata.create_all(_engine, checkfirst=False)
        return True

    # tables/outfile carry string values (a comma-separated list and a file
    # path), so they must not be declared as boolean store_true flags.
    @arg('outfile', default=None, optional=True)
    @arg('tables', default=None, optional=True)
    @arg('connection_string', optional=True)
    def generate_models(self, connection_string=None, tables=None,
                        outfile=None):
        """Generate models from an existing database schema.

        Args:
            connection_string (string): The database to connect to
            tables (string): Tables to process (comma-separated, default: all)
            outfile (string): File to write output to (default: stdout)
        """
        tables = tables.split(',') if tables else None
        if connection_string:
            databases = {connection_string.split('/')[-1]: connection_string}
        else:
            # Only the connection strings are needed; the model metadata is
            # deliberately not loaded here.
            databases = dict(self.connections)
        if outfile:
            stream = codecs.open(outfile, 'w', encoding='utf-8')
        else:
            stream = sys.stdout
        try:
            for database, connection in databases.items():
                self.write(
                    'SqlAlchemy model classes for "{}"'.format(database))
                self.write()
                _engine = create_engine(connection)
                metadata = schema.MetaData(_engine)
                metadata.reflect(_engine, only=tables)
                generator = codegen.CodeGenerator(metadata)
                generator.render(stream)
        finally:
            # Close only the file opened here, never the process's stdout.
            if stream is not sys.stdout:
                stream.close()
class Migrate(command.Base, BaseDatabaseCommand):
    """Alembic integration with Watson.
    """
    name = 'db:migrate'

    def _check_migrations(self):
        """Raise ConsoleError unless the config has a 'migrations' section."""
        if 'migrations' not in self.config:
            raise ConsoleError(
                'No migrations configuration can be found.')

    @property
    def database_names(self):
        """List of the configured connection names."""
        return list(self.config['connections'])

    @property
    def directory(self):
        """Absolute path of the configured migrations directory."""
        self._check_migrations()
        return os.path.abspath(self.config['migrations']['path'])

    @property
    def alembic_config_file(self):
        """Path of ``alembic.ini`` inside the migrations directory."""
        return os.path.join(self.directory, 'alembic.ini')

    def alembic_config(self, with_ini=True, relative_script_location=True):
        """Build the Alembic configuration object used by the commands.

        Args:
            with_ini (bool): Pass the alembic.ini path to the Config
            relative_script_location (bool): Use the migrations path exactly
                as configured; when False it is made absolute first

        Returns:
            A Config with 'script_location' and 'databases' set and a
            ``watson`` attribute holding this command's config and container.

        Raises:
            ConsoleError: If no migrations configuration is present.
        """
        self._check_migrations()
        directory = self.config['migrations']['path']
        args = [self.alembic_config_file] if with_ini else []
        config = Config(*args)
        if relative_script_location:
            script_location = directory
        else:
            script_location = os.path.abspath(directory)
        config.set_main_option('script_location', script_location)
        config.set_main_option('databases', ', '.join(self.database_names))
        config.watson = {
            'config': self.config,
            'container': self.container
        }
        return config

    @arg()
    def init(self):
        """Initializes Alembic migrations for the project.
        """
        # Built without the ini file; its path is assigned afterwards so it
        # is available to alembic's init command.
        config = self.alembic_config(with_ini=False)
        config.config_file_name = self.alembic_config_file
        alembic_command.init(
            config, config.get_main_option('script_location'), 'watson')
        return True

    @arg()
    def history(self, rev_range):
        """List changeset scripts in chronological order.

        Args:
            rev_range: Revision range in format [start]:[end]
        """
        config = self.alembic_config()
        alembic_command.history(config, rev_range)
        return True

    @arg()
    def current(self):
        """Display the current revision for each database.
        """
        config = self.alembic_config()
        alembic_command.current(config)
        return True

    @arg('sql', action='store_true', default=False, optional=True)
    @arg('autogenerate', action='store_true', default=False, optional=True)
    @arg('message', optional=True, default='Revision')
    def revision(self, sql=False, autogenerate=False, message=None):
        """Create a new revision file.

        Args:
            sql (bool): Don't emit SQL to database - dump to standard output instead
            autogenerate (bool): Populate revision script with candidate migration operations, based on comparison of database to model
            message (string): Message string to use with 'revision'
        """
        config = self.alembic_config(relative_script_location=False)
        return alembic_command.revision(
            config, message, autogenerate=autogenerate, sql=sql)

    # The declared default now matches the signature default ('head') and
    # the declaration used by upgrade, instead of None.
    @arg('sql', action='store_true', default=False, optional=True)
    @arg('tag', default=None, optional=True)
    @arg('revision', default='head', nargs='?')
    def stamp(self, sql=False, tag=None, revision='head'):
        """'stamp' the revision table with the given revision; don't run any migrations.

        Args:
            sql (bool): Don't emit SQL to database - dump to standard output instead
            tag (string): Arbitrary 'tag' name - can be used by custom env.py scripts
            revision (string): Revision identifier
        """
        config = self.alembic_config()
        alembic_command.stamp(config, revision, tag=tag, sql=sql)
        return True

    @arg('sql', action='store_true', default=False, optional=True)
    @arg('tag', default=None, optional=True)
    @arg('revision', default='head', nargs='?')
    def upgrade(self, sql=False, tag=None, revision='head'):
        """Upgrade to a later version.

        Args:
            sql (bool): Don't emit SQL to database - dump to standard output instead
            tag (string): Arbitrary 'tag' name - can be used by custom env.py scripts
            revision (string): Revision identifier
        """
        config = self.alembic_config()
        alembic_command.upgrade(config, revision, tag=tag, sql=sql)
        return True

    @arg('sql', action='store_true', default=False, optional=True)
    @arg('tag', default=None, optional=True)
    @arg('revision', default='-1', nargs='?')
    def downgrade(self, sql=False, tag=None, revision='-1'):
        """Revert to a previous version.

        Args:
            sql (bool): Don't emit SQL to database - dump to standard output instead
            tag (string): Arbitrary 'tag' name - can be used by custom env.py scripts
            revision (string): Revision identifier
        """
        config = self.alembic_config()
        alembic_command.downgrade(config, revision, tag=tag, sql=sql)
        return True

    @arg()
    def branches(self):
        """Show current un-spliced branch points.
        """
        config = self.alembic_config()
        alembic_command.branches(config)
        return True