# Author Mikko Kortelainen <mikko.kortelainen@techelp.fi>
# Tested (not very thoroughly) with Python 2.7.5 and 3.3.2, and SQLAlchemy 0.8.2

import sqlalchemy
from sqlalchemy import create_engine, MetaData, Table, Sequence
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.engine import reflection

class SQLReflector(object):
    """Reflects SQLAlchemy declarative classes from database.

    Requires the SQLAlchemy database engine as argument. Tested only with
    PostgreSQL. For example:

    from sqlalchemy import create_engine
    connstr = "postgresql+psycopg2://user:pass@host/database?sslmode=require"
    engine = create_engine(connstr)
    reflector = SQLReflector(engine)
    session = reflector.session

    MyTable = reflector.reflect_table("my_schema", "my_table")
    a_row = session.query(MyTable).filter(MyTable.id == 1).one()

    After a table has been reflected, it is saved in the SQLReflector object.
    You can find them under the classes property:

    reflector.classes.<schemaname>.<tablename>

    By default schema and table names are sanitized so that characters which
    are not letters, numbers or underscores will be stripped from the name.
    Also, initial numeric characters will be stripped. What's left is treated
    as the identifier. For table names (not schema names), the default is
    also to convert the resulting name into CamelCase by first converting
    it to lowercase and after that removing all underscore characters and
    converting the character following the underscore to uppercase. For
    example:

      Table "my_data" becomes class "MyData"
      Table "MyData" becomes class "Mydata"

    You can turn CamelCase conversion off by giving camelcase=False.

    If you give sanitize_names=False, the names are not sanitized at all
    before making Python identifiers out of them. So your schemas and tables
    should have names that are compliant already.
    """

    def __init__(self, database_engine, sanitize_names=True, camelcase=True):
        """Binds a session, MetaData and Inspector to database_engine.

        sanitize_names and camelcase are stored as the defaults used by the
        reflect_* methods when those methods receive None for the same
        arguments.
        """
        self.engine = database_engine
        self.sessionmaker = sessionmaker(bind=self.engine)
        self.session = self.sessionmaker()
        self.metadata = MetaData(bind=self.engine)
        self.inspector = reflection.Inspector.from_engine(self.engine)

        # Namespace holding reflected classes as classes.<schema>.<class>
        self.classes = self.make_empty_object()
        self.sanitize_names = sanitize_names
        self.camelcase = camelcase

    def __repr__(self):
        """Formal representation"""
        return "<SQLReflector sanitize_names=%s camelcase=%s>" % (
            bool(self.sanitize_names), bool(self.camelcase))

    def get_schema_names(self):
        """Returns a list of available schemas in database"""
        return self.inspector.get_schema_names()

    def get_table_names(self, schema_name):
        """Returns a list of available tables in a schema.

        For more in-depth exploration you can use the SQLAlchemy
        reflection.Inspector instance inside the class. The member is called
        "inspector".
        """
        return self.inspector.get_table_names(schema=schema_name)

    def reflect_table(self, schema_name, table_name, column_definitions=None,
                      sanitize_names=None, camelcase=None):
        """Reflects a table's metadata and creates an SQLAlchemy declarative
        class with a nice sanitized CamelCase name out of it.

        sanitize_names and camelcase override the instance defaults for this
        call only; None means "use the value given to the constructor".

        You may give extra column definitions or pretty much anything you want
        to add to the class in the column_definitions argument. It must be an
        iterable (note the trailing comma in the one-element tuples below).
        For example to define an integer primary key to a PostgreSQL view:

        from sqlalchemy import Column, Integer
        MyView  = reflector.reflect_table(
                      'my_schema',
                      'my_view',
                      column_definitions=
                      (
                          Column('id_col', Integer, primary_key=True),
                      )
                  )

        Or you can define an explicit foreign key if such is not in place in the
        database:

        from sqlalchemy import Column, Integer, ForeignKey
        MyTable = reflector.reflect_table(
                      'my_schema',
                      'my_table',
                      column_definitions=
                      (
                          Column(
                              'fk_col',
                              Integer,
                              ForeignKey('foreign_schema.foreign_table.foreign_column')
                          ),
                      )
                  )

        Returns the new class, which is also stored as
        self.classes.<schema name>.<class name>.
        """
        # Fall back to the instance-wide defaults
        if camelcase is None:
            camelcase = self.camelcase
        if sanitize_names is None:
            sanitize_names = self.sanitize_names

        # Sanitize names if requested
        if sanitize_names:
            class_name = self.sanitize_tablename(table_name, camelcase)
            sane_schema_name = self.sanitize_schemaname(schema_name)
        else:
            class_name = table_name
            sane_schema_name = schema_name

        # Extra definitions (if any) are handed to Table as positional
        # arguments in front of the reflected columns
        if column_definitions is None:
            extra_definitions = ()
        else:
            extra_definitions = tuple(column_definitions)

        # The database is always addressed with the original, unsanitized
        # schema and table names
        table_definition = Table(
            table_name,
            self.metadata,
            *extra_definitions,
            schema=schema_name,
            autoload=True)

        # Make a proper class out of it, on a fresh declarative base
        Base = declarative_base()
        the_class = type(class_name, (Base,), {"__table__": table_definition})

        # Save it as self.classes.<sane_schema_name>.<class_name>
        if not hasattr(self.classes, sane_schema_name):
            setattr(self.classes, sane_schema_name, self.make_empty_object())
        setattr(getattr(self.classes, sane_schema_name), class_name, the_class)

        # Also return it
        return the_class

    def reflect_tables(self, schema_name, table_name_list,
                       sanitize_names=None, camelcase=None):
        """Makes a list of classes from a list of table names that are all
        defined in the given schema"""
        # The flags must go by keyword: the third positional parameter of
        # reflect_table is column_definitions, not sanitize_names.
        return [
            self.reflect_table(
                schema_name,
                table_name,
                sanitize_names=sanitize_names,
                camelcase=camelcase)
            for table_name in table_name_list
        ]

    def reflect_schema(self, schema_name, sanitize_names=None, camelcase=None):
        """Reflects all tables in a given schema"""
        return self.reflect_tables(
            schema_name,
            self.get_table_names(schema_name),
            sanitize_names=sanitize_names,
            camelcase=camelcase)

    def reflect_database(self, sanitize_names=None, camelcase=None):
        """Reflects all tables in every schema of the database"""
        classes = []
        for schema_name in self.get_schema_names():
            classes.extend(
                self.reflect_schema(
                    schema_name,
                    sanitize_names=sanitize_names,
                    camelcase=camelcase))
        return classes

    def make_empty_object(self):
        """Create an empty object. It is a subclass of object, with the
        distinction, that this one can have its attributes manipulated with
        setattr"""
        EmptyObject = type('EmptyObject', (object,), dict())
        return EmptyObject()

    def underscore_to_camelcase(self, value):
        """Converts an underscore_separated name to camelCase.

        The first non-empty segment is lowercased and every later non-empty
        segment is capitalized (first letter upper, rest lower). Each empty
        segment, which comes from a leading, trailing or doubled underscore,
        is emitted as a single underscore.
        """
        pieces = []
        seen_word = False
        for segment in value.split("_"):
            if not segment:
                pieces.append("_")
            elif seen_word:
                pieces.append(segment.capitalize())
            else:
                pieces.append(segment.lower())
                seen_word = True
        return "".join(pieces)

    def capitalize_first_letter(self, string):
        """Returns string with its first character uppercased and the rest
        left untouched. An empty string is returned as is."""
        # Slicing instead of indexing so that "" does not raise IndexError
        return string[:1].upper() + string[1:]

    def to_unicode(self, thing):
        """For Python 2 and 3 compatibility"""
        import sys
        # Compare the numeric major version; comparing the version string
        # lexically would misclassify a major version of 10 or above.
        if sys.version_info[0] < 3:
            return unicode(thing)  # noqa: F821 (only defined on Python 2)
        return str(thing)

    def to_string(self, thing):
        """Returns str(thing)"""
        return str(thing)

    def _ascii_identifier(self, name):
        """Returns name with accents removed, converted to ASCII and with
        every character outside [A-Za-z0-9_] removed."""
        from unicodedata import normalize
        from re import sub

        # Decompose accented letters so that the ASCII encoding step keeps
        # the base letter and drops only the combining mark
        decomposed = normalize("NFKD", self.to_unicode(name))
        ascii_only = self.to_string(
            decomposed.encode("ASCII", "ignore").decode())

        # Remove any unwanted characters
        return sub(r"[^A-Za-z0-9_]+", "", ascii_only)

    def sanitize_tablename(self, tablename, camelcase=True):
        """Returns a suitable class name derived from tablename
        - accents will be removed from letters
        - other than valid characters will be removed: [^A-Za-z0-9_]
        - underscores will be converted into CamelCase unless camelcase=False
        - initial numeric characters will be stripped
        - first character will be capitalized (only if camelcase is true)

        The result is an empty string when nothing usable is left.
        """
        from re import sub

        cleaned = self._ascii_identifier(tablename)

        # Convert underscore_identifiers to CamelCase
        if camelcase:
            cleaned = self.underscore_to_camelcase(cleaned)

        # Remove initial numeric characters (class names cannot start with a
        # number)
        classname = sub(r"^[0-9]+", "", cleaned)

        # Capitalize the first letter (only if CamelCase)
        if camelcase:
            classname = self.capitalize_first_letter(classname)

        return classname

    def sanitize_schemaname(self, schemaname):
        """This sanitizes a schema name so that it can be used as a Python
        object.
        - accents will be removed from letters
        - other than valid characters will be removed: [^A-Za-z0-9_]
        - initial numeric characters will be stripped (must be A-Za-z_)
        """
        from re import sub

        # Remove initial numeric characters (identifiers cannot start with a
        # number)
        return sub(r"^[0-9]+", "", self._ascii_identifier(schemaname))
