Python peewee.SQL Examples

The following are 24 code examples of peewee.SQL(). Each example is labeled with the source file and project it came from, along with that project's license. You may also want to check out all available functions/classes of the module peewee, or try the search function.
Example #1
Source File: text.py (from quay, Apache License 2.0), 6 votes
def match_mysql(field, search_query):
    """
    Build a MySQL full-text ``MATCH(...) AGAINST(...)`` clause for ``field``.

    :param field: model field; its ``name`` is backtick-quoted and used as the column.
    :param search_query: raw search text. ``*``, single quotes, double quotes and backticks are
        stripped, and the remainder is passed as a bound ``%s`` parameter.
    :return: a parenthesized ``NodeList`` containing the MATCH and AGAINST calls.
    :raises Exception: if ``field.name`` contains a backtick.
    """
    if "`" in field.name:  # Just to be safe.
        raise Exception("How did field name '%s' end up containing a backtick?" % field.name)

    # `*` is removed because of a known MySQL bug (https://bugs.mysql.com/bug.php?id=78485) that
    # makes such queries raise a parsing error. Quotes and backticks are removed defensively, even
    # though the text is bound as a parameter below.
    sanitized = search_query
    for unwanted in ("*", "'", '"', "`"):
        sanitized = sanitized.replace(unwanted, "")

    match_call = fn.MATCH(SQL("`%s`" % field.name))
    against_call = fn.AGAINST(SQL("%s", [sanitized]))
    return NodeList((match_call, against_call), parens=True)
Example #2
Source File: peeweedbevolve.py (from peewee-db-evolve, GNU Lesser General Public License v3.0), 6 votes
def _execute(db, to_run, interactive=True, commit=True):
  """
  Execute ``(sql, params)`` statements inside a single ``db.atomic()`` transaction.

  Statements whose text starts with ``--`` are printed (when interactive or DEBUG) but never
  sent to the database. When ``commit`` is False, ``txn.rollback()`` is called after all
  statements have run. On any exception an error banner is printed and the original exception
  is re-raised.

  :param db: database exposing ``atomic()`` and ``execute_sql(sql, params)``.
  :param to_run: iterable of ``(sql, params)`` pairs.
  :param interactive: print each statement plus a success banner.
  :param commit: if False, roll the transaction back instead of keeping the changes.
  """
  if interactive: print()
  try:
    with db.atomic() as txn:
      for sql, params in to_run:
        if interactive or DEBUG: print_sql(' %s; %s' % (sql, params or ''))
        # Comment-only "statements" are shown but not executed.
        if sql.strip().startswith('--'): continue
        db.execute_sql(sql, params)
      if interactive:
        print()
        print(
          (colorama.Style.BRIGHT + 'SUCCESS!' + colorama.Style.RESET_ALL) if commit else 'TEST PASSED - ROLLING BACK',
          colorama.Style.DIM + '-',
          'https://github.com/keredson/peewee-db-evolve' + colorama.Style.RESET_ALL
        )
        print()
      if not commit:
        txn.rollback()
  except Exception:
    print()
    print('------------------------------------------')
    print(colorama.Style.BRIGHT + colorama.Fore.RED + ' SQL EXCEPTION - ROLLING BACK ALL CHANGES' + colorama.Style.RESET_ALL)
    print('------------------------------------------')
    print()
    # A bare ``raise`` re-raises the active exception with its original traceback.
    raise
Example #3
Source File: peeweedbevolve.py (from peewee-db-evolve, GNU Lesser General Public License v3.0), 6 votes
def add_not_null(db, migrator, table, column_name, field):
    """
    Return the SQL commands that make ``column_name`` of ``table`` NOT NULL.

    When ``field.default`` is set, an ``UPDATE <table> SET <col> = <default> WHERE <col> IS NULL``
    command is emitted first so existing NULL rows are backfilled. The NOT NULL change itself comes
    from ``migrator.add_not_null`` on Postgres/SQLite and is built as
    ``ALTER TABLE <table> MODIFY <field definition>`` on MySQL.

    :return: list of commands produced by the compiler / migrator helpers.
    :raises Exception: if ``db`` is none of Postgres, SQLite or MySQL.
    """
    cmds = []
    compiler = db.compiler()
    if field.default is not None:
        # If the default is a function, call it once to turn it into a value. This won't work on
        # columns requiring uniqueness, like UUIDs, as every backfilled row shares the same value.
        default = field.default() if callable(field.default) else field.default
        op = pw.Clause(
            pw.SQL('UPDATE'), pw.Entity(table), pw.SQL('SET'), field.as_entity(), pw.SQL('='),
            default, pw.SQL('WHERE'), field.as_entity(), pw.SQL('IS NULL'),
        )
        cmds.append(compiler.parse_node(op))
    if is_postgres(db) or is_sqlite(db):
        junk = migrator.add_not_null(table, column_name, generate=True)
        cmds += normalize_whatever_junk_peewee_migrations_gives_you(migrator, junk)
        return cmds
    elif is_mysql(db):
        op = pw.Clause(
            pw.SQL('ALTER TABLE'), pw.Entity(table), pw.SQL('MODIFY'),
            compiler.field_definition(field),
        )
        cmds.append(compiler.parse_node(op))
        return cmds
    raise Exception('how do i add a not null for %s?' % db)
Example #4
Source File: peewee_async.py (from peewee-async, MIT License), 6 votes
def execute_sql(self, *args, **kwargs):
        """Execute a SQL query synchronously; only permitted when ``_allow_sync`` is truthy.

        Fails with ``AssertionError`` when sync queries are disallowed. When ``_allow_sync`` equals
        ``logging.ERROR`` or ``logging.WARNING``, the call is still allowed, but a message is logged
        at that level before delegating to the parent class's ``execute_sql``.
        """
        level = self._allow_sync
        assert level, (
            "Error, sync query is not allowed! Call the `.set_allow_sync()` "
            "or use the `.allow_sync()` context manager.")
        if level in (logging.ERROR, logging.WARNING):
            message = "Error, sync query is not allowed: %s %s" % (str(args), str(kwargs))
            logging.log(level, message)
        return super().execute_sql(*args, **kwargs)


##############
# PostgreSQL #
############## 
Example #5
Source File: peewee_async.py (from peewee-async, MIT License), 6 votes
async def count(query, clear_limit=False):
    """Perform *COUNT* aggregated query asynchronously.

    Queries using DISTINCT, GROUP BY, LIMIT or OFFSET are wrapped as
    ``SELECT COUNT(1) FROM (<query>) AS wrapped_select``. All other queries have their select
    list replaced by ``COUNT(*)`` and their ORDER BY dropped.

    :param query: peewee select query to count.
    :param clear_limit: when wrapping, drop LIMIT/OFFSET from the inner query first.
    :return: number of objects in ``select()`` query (0 when the scalar result is falsy)
    """
    clone = query.clone()
    if query._distinct or query._group_by or query._limit or query._offset:
        if clear_limit:
            clone._limit = clone._offset = None
        sql, params = clone.sql()
        wrapped = 'SELECT COUNT(1) FROM (%s) AS wrapped_select' % sql
        raw = query.model.raw(wrapped, *params)
        return (await scalar(raw)) or 0
    else:
        clone._returning = [peewee.fn.Count(peewee.SQL('*'))]
        clone._order_by = None
        return (await scalar(clone)) or 0
Example #6
Source File: test_models.py (from aiopeewee, MIT License), 5 votes
async def test_select_all(flushdb):
    """Selecting ``SQL('*')`` returns complete Blog rows whose ``user`` relation can be awaited.

    ``flushdb`` is a fixture defined elsewhere (presumably resetting the test database — verify).
    """
    await create_users_blogs(2, 2)
    all_cols = SQL('*')
    query = Blog.select(all_cols)
    blogs = [blog async for blog in query.order_by(Blog.pk)]
    assert [b.title for b in blogs] == ['b-0-0', 'b-0-1', 'b-1-0', 'b-1-1']
    assert [(await b.user).username for b in blogs] == ['u0', 'u0', 'u1', 'u1']
Example #7
Source File: text.py (from quay, Apache License 2.0), 5 votes
def prefix_search(field, prefix_query):
    """
    Returns the wildcard match for searching for the given prefix query.

    ``prefix_query`` is escaped with ``_escape_wildcard`` and suffixed with ``%``. The trailing
    ``ESCAPE '!'`` clause declares ``!`` as the escape character. The comparison is built with
    ``Field.__pow__``, peewee's ``**`` (ILIKE) operator.
    """
    pattern = _escape_wildcard(prefix_query) + "%"
    escape_clause = SQL("ESCAPE '!'")
    return Field.__pow__(field, NodeList((pattern, escape_clause)))
Example #8
Source File: text.py (from quay, Apache License 2.0), 5 votes
def match_like(field, search_query):
    """
    Generates a full-text match query using an ILIKE operation, which is needed for SQLite and
    Postgres.

    ``search_query`` is escaped with ``_escape_wildcard`` and wrapped in ``%`` on both sides, so it
    matches anywhere in the field. ``ESCAPE '!'`` declares ``!`` as the escape character.
    """
    pattern = "%" + _escape_wildcard(search_query) + "%"
    return Field.__pow__(field, NodeList((pattern, SQL("ESCAPE '!'"))))
Example #9
Source File: peeweedbevolve.py (from peewee-db-evolve, GNU Lesser General Public License v3.0), 5 votes
def change_column_type(db, migrator, table_name, column_name, field):
    """
    Build the statement that changes ``field``'s column to the type from ``_field_type(field)``.

    Postgres: ``ALTER TABLE <t> ALTER <col> TYPE <ddl column>``.
    MySQL: ``ALTER TABLE <t> MODIFY <field ddl>``.

    ``column_name`` is not used by this function; the column comes from ``field``.

    :raises Exception: for any other database.
    """
    column_type = _field_type(field)
    if is_postgres(db):
        parts = [
            pw.SQL('ALTER TABLE'), pw.Entity(table_name), pw.SQL('ALTER'),
            field.as_entity(), pw.SQL('TYPE'), field.__ddl_column__(column_type),
        ]
    elif is_mysql(db):
        prefix = [pw.SQL('ALTER TABLE'), pw.Entity(table_name), pw.SQL('MODIFY')]
        parts = prefix + field.__ddl__(column_type)
    else:
        raise Exception('how do i change a column type for %s?' % db)
    return normalize_whatever_junk_peewee_migrations_gives_you(migrator, pw.Clause(*parts))
Example #10
Source File: peeweedbevolve.py (from peewee-db-evolve, GNU Lesser General Public License v3.0), 5 votes
def set_default(db, migrator, table_name, column_name, field):
    """
    Build ``ALTER TABLE <table> ALTER COLUMN <column> SET DEFAULT <value>``.

    A callable ``field.default`` is called once to obtain the value. The value is converted with
    ``field.db_value`` and emitted as a bound parameter.
    """
    default_value = field.default
    if callable(default_value):
        default_value = default_value()
    param = pw.Param(field.db_value(default_value))
    clause = pw.Clause(
        pw.SQL('ALTER TABLE'), pw.Entity(table_name),
        pw.SQL('ALTER COLUMN'), pw.Entity(column_name),
        pw.SQL('SET DEFAULT'), param,
    )
    return normalize_whatever_junk_peewee_migrations_gives_you(migrator, clause)
Example #11
Source File: peeweedbevolve.py (from peewee-db-evolve, GNU Lesser General Public License v3.0), 5 votes
def drop_default(db, migrator, table_name, column_name, field):
    """
    Build ``ALTER TABLE <table> ALTER COLUMN <column> DROP DEFAULT``.

    ``db`` and ``field`` are not used by this function.
    """
    clause = pw.Clause(
        pw.SQL('ALTER TABLE'), pw.Entity(table_name),
        pw.SQL('ALTER COLUMN'), pw.Entity(column_name),
        pw.SQL('DROP DEFAULT'),
    )
    return normalize_whatever_junk_peewee_migrations_gives_you(migrator, clause)
Example #12
Source File: peeweedbevolve.py (from peewee-db-evolve, GNU Lesser General Public License v3.0), 5 votes
def drop_foreign_key(db, migrator, table_name, fk_name):
    """
    Build the statement that drops foreign key ``fk_name`` from ``table_name``.

    MySQL gets ``drop foreign key`` (emitted in lowercase). Every other backend gets
    ``DROP CONSTRAINT``.
    """
    if is_mysql(db):
        drop_keyword = 'drop foreign key'
    else:
        drop_keyword = 'DROP CONSTRAINT'
    clause = pw.Clause(
        pw.SQL('ALTER TABLE'), pw.Entity(table_name), pw.SQL(drop_keyword), pw.Entity(fk_name)
    )
    return normalize_whatever_junk_peewee_migrations_gives_you(migrator, clause)
Example #13
Source File: peeweedbevolve.py (from peewee-db-evolve, GNU Lesser General Public License v3.0), 5 votes
def drop_table(migrator, table_name):
    """Return a one-element list holding the compiled ``DROP TABLE <table_name>`` statement."""
    compiler = migrator.database.compiler()
    statement = pw.Clause(pw.SQL('DROP TABLE'), pw.Entity(table_name))
    compiled = compiler.parse_node(statement)
    return [compiled]
Example #14
Source File: peeweedbevolve.py (from peewee-db-evolve, GNU Lesser General Public License v3.0), 5 votes
def set_default(db, migrator, table_name, column_name, field):
    """
    Build ``UPDATE <table> SET <column> = <default> WHERE <column> IS NULL``.

    Despite the name, this does not alter the column's DEFAULT. It backfills rows whose value is
    NULL with ``field``'s default. A callable ``field.default`` is called once, and the value is
    converted with ``field.db_value``.
    """
    default = field.default
    if callable(default):
        default = default()
    ctx = migrator.make_context()
    ctx = ctx.literal('UPDATE ').sql(pw.Entity(table_name))
    ctx = ctx.literal(' SET ').sql(
        pw.Expression(pw.Entity(column_name), pw.OP.EQ, field.db_value(default), flat=True)
    )
    ctx = ctx.literal(' WHERE ').sql(
        pw.Expression(pw.Entity(column_name), pw.OP.IS, pw.SQL('NULL'), flat=True)
    )
    return extract_query_from_migration(ctx)
Example #15
Source File: query.py (from aiopeewee, MIT License), 5 votes
async def exists(self):
        """Return True if the query matches at least one row.

        Fetches page 1 with a page size of 1, selecting only ``SQL('1')``, and awaits ``scalar()``.
        """
        clone = self.paginate(1, 1)
        clone._select = [SQL('1')]
        return bool(await clone.scalar())
Example #16
Source File: database.py (from aiopeewee, MIT License), 5 votes
def default_insert_clause(self, model_class):
        """Return the ``DEFAULT VALUES`` clause used for an INSERT without explicit values.

        ``model_class`` is not used by this implementation.
        """
        clause = SQL('DEFAULT VALUES')
        return clause
Example #17
Source File: peewee_async.py (from peewee-async, MIT License), 5 votes
async def _run_sql(database, operation, *args, **kwargs):
    """Run SQL operation (query or command) against database.

    A cursor is obtained from ``database.cursor_async()`` and ``operation`` is executed with the
    given arguments inside ``peewee.__exception_wrapper__``. On success the cursor is returned
    without being released. If ``execute`` raises anything (``BaseException`` is caught, so this
    includes task cancellation), the cursor is released and the exception re-raised.
    """
    __log__.debug((operation, args, kwargs))

    with peewee.__exception_wrapper__:
        cursor = await database.cursor_async()

        try:
            await cursor.execute(operation, *args, **kwargs)
        except BaseException:
            await cursor.release()
            raise

        return cursor
Example #18
Source File: query.py (from torpeewee, MIT License), 5 votes
async def exists(self, database):
        """Return True if the query yields at least one row.

        The select list is replaced with ``SQL('1')``, LIMIT is set to 1 and OFFSET is cleared
        before awaiting ``scalar()``. ``database`` is accepted but not used in this body.
        """
        clone = self.columns(SQL('1'))
        clone._limit = 1
        clone._offset = None
        return bool((await clone.scalar()))
Example #19
Source File: query.py (from torpeewee, MIT License), 5 votes
def count(self, database, clear_limit=False):
        """Return ``SELECT COUNT(1)`` over this query wrapped as a ``_wrapped`` subquery.

        ORDER BY is removed first. When the query has no HAVING, GROUP BY, window or DISTINCT
        clause, its select list is replaced by ``SQL('1')``, because the selected values do not
        affect the row count in that case.

        :param database: database passed to ``scalar``.
        :param clear_limit: drop LIMIT/OFFSET before counting.
        :return: whatever ``Select(...).scalar(database)`` returns.
        """
        clone = self.order_by().alias('_wrapped')
        if clear_limit:
            clone._limit = clone._offset = None
        try:
            # GROUP BY must keep its select list: a GROUP BY that references a select-list alias
            # would no longer compile after the replacement. The count is the same either way.
            if clone._having is None and clone._group_by is None and \
                    clone._windows is None and clone._distinct is None and \
                    clone._simple_distinct is not True:
                clone = clone.select(SQL('1'))
        except AttributeError:
            # Queries without these attributes are counted with their select list untouched.
            pass
        return Select([clone], [fn.COUNT(SQL('1'))]).scalar(database)
Example #20
Source File: repository.py (from quay, Apache License 2.0), 4 votes
def get_visible_repositories(
    username, namespace=None, kind_filter="image", include_public=False, start_id=None, limit=None
):
    """
    Returns the repositories visible to the given user (if any).

    Repositories marked for deletion are excluded. When ``username`` is given, the query is joined
    (LEFT OUTER) to ``RepositoryPermission`` and made DISTINCT. The remaining filtering is delegated
    to ``_basequery.filter_to_repos_for_user``. With ``limit`` the query is limited and ordered by
    the ``rid`` alias.

    If nothing can be visible (no user and no public repositories requested), or ``username`` does
    not resolve to a namespace, a query matching no rows is returned instead of a list.
    """

    def no_repositories():
        # No repository has id -1. A query (not a list) is returned because callers modify it.
        return Repository.select(Repository.id.alias("rid")).where(Repository.id == -1)

    if not (include_public or username):
        return no_repositories()

    columns = (
        Repository.name,
        Repository.id.alias("rid"),
        Repository.description,
        Namespace.username,
        Repository.visibility,
        Repository.kind,
        Repository.state,
    )
    query = (
        Repository.select(*columns)
        .switch(Repository)
        .join(Namespace, on=(Repository.namespace_user == Namespace.id))
        .where(Repository.state != RepositoryState.MARKED_FOR_DELETION)
    )

    user_id = None
    if username:
        # The permissions table is only needed when filtering on a user's permissions.
        query = query.switch(Repository).distinct().join(RepositoryPermission, JOIN.LEFT_OUTER)
        found_namespace = _get_namespace_user(username)
        if not found_namespace:
            return no_repositories()
        user_id = found_namespace.id

    query = _basequery.filter_to_repos_for_user(
        query, user_id, namespace, kind_filter, include_public, start_id=start_id
    )

    if limit is None:
        return query
    return query.limit(limit).order_by(SQL("rid"))
Example #21
Source File: modelutil.py (from quay, Apache License 2.0), 4 votes
def paginate(
    query,
    model,
    descending=False,
    page_token=None,
    limit=50,
    sort_field_alias=None,
    max_page=None,
    sort_field_name=None,
):
    """
    Paginates the given query using a field range, starting at the optional page_token.

    Returns a *list* of matching results along with an unencrypted page_token for the next page, if
    any. If descending is set to True, orders by the field descending rather than ascending.

    :param model: model providing the sort field ``sort_field_name`` (default ``"id"``).
    :param page_token: optional token. Its start index comes from ``pagination_start`` and its
        ``"page_number"`` entry (if any) is compared against ``max_page``.
    :param limit: page size; ``limit + 1`` rows are requested.
    :param sort_field_alias: when given, ``SQL(alias)`` is used as the sort field, and the alias is
        passed to ``paginate_query`` as the sort field name.
    :param max_page: when the token's page number exceeds it, ``([], None)`` is returned.
    """
    # Note: We use the sort_field_alias for the order_by, but not the where below. The alias is
    # necessary for certain queries that use unions in MySQL, as it gets confused on which field
    # to order by. The where clause, on the other hand, cannot use the alias because Postgres does
    # not allow aliases in where clauses.
    # NOTE(review): as written, the range filter below ALSO uses ``SQL(alias)`` when an alias is
    # given, which contradicts the note above. Confirm which behavior callers rely on.
    field_name = sort_field_name or "id"
    sort_field = getattr(model, field_name)
    if sort_field_alias is not None:
        field_name = sort_field_alias
        sort_field = SQL(sort_field_alias)

    query = query.order_by(sort_field.desc() if descending else sort_field)

    start_index = pagination_start(page_token)
    if start_index is not None:
        bound = (sort_field <= start_index) if descending else (sort_field >= start_index)
        query = query.where(bound)

    query = query.limit(limit + 1)

    page_number = None
    if page_token:
        page_number = page_token.get("page_number") or None
    if page_number is not None and max_page is not None and page_number > max_page:
        return [], None

    return paginate_query(query, limit=limit, sort_field_name=field_name, page_number=page_number)
Example #22
Source File: notification.py (from quay, Apache License 2.0), 4 votes
def list_notifications(
    user, kind_name=None, id_filter=None, include_dismissed=False, page=None, limit=None
):
    """
    Return a query for the notifications visible to ``user``.

    The result combines (with ``|``) notifications targeted at ``user`` directly and notifications
    targeted at organizations where ``user`` is a member of a team whose role is ``"admin"``. It is
    ordered by the ``cd`` (created) alias, descending.

    :param kind_name: restrict to this ``NotificationKind.name``.
    :param id_filter: restrict to the notification with this uuid.
    :param include_dismissed: include dismissed notifications (excluded by default).
    :param page: if truthy, passed to ``paginate`` together with ``limit``.
    :param limit: row limit, applied on its own when ``page`` is not given.
    """
    base_query = Notification.select(
        Notification.id,
        Notification.uuid,
        Notification.kind,
        Notification.metadata_json,
        Notification.dismissed,
        Notification.lookup_path,
        Notification.created,
        Notification.created.alias("cd"),
        Notification.target,
    ).join(NotificationKind)

    conditions = []
    if kind_name is not None:
        conditions.append(NotificationKind.name == kind_name)
    if id_filter is not None:
        conditions.append(Notification.uuid == id_filter)
    if not include_dismissed:
        # ``==`` (not ``is``) is required so peewee builds a SQL comparison.
        conditions.append(Notification.dismissed == False)
    for condition in conditions:
        base_query = base_query.where(condition)

    # Notifications addressed to the user directly.
    direct = base_query.clone().where(Notification.target == user)

    # Notifications addressed to organizations the user administers through an admin-role team.
    OrgUser = User.alias()
    AdminTeam = Team.alias()
    AdminMembership = TeamMember.alias()
    AdminUser = User.alias()

    via_admined_orgs = (
        base_query.clone()
        .join(OrgUser, on=(OrgUser.id == Notification.target))
        .join(AdminTeam, on=(OrgUser.id == AdminTeam.organization))
        .join(TeamRole, on=(AdminTeam.role == TeamRole.id))
        .switch(AdminTeam)
        .join(AdminMembership, on=(AdminTeam.id == AdminMembership.team))
        .join(AdminUser, on=(AdminMembership.user == AdminUser.id))
        .where((AdminUser.id == user) & (TeamRole.name == "admin"))
    )

    combined = direct | via_admined_orgs

    if page:
        combined = combined.paginate(page, limit)
    elif limit:
        combined = combined.limit(limit)

    return combined.order_by(SQL("cd desc"))
Example #23
Source File: repository.py (from quay, Apache License 2.0), 4 votes
def _get_sorted_matching_repositories(
    lookup_value, repo_kind="image", include_private=False, search_fields=None, ids_only=False
):
    """
    Returns a query of repositories matching the given lookup string, with optional inclusion of
    private repositories.

    Note that this method does *not* filter results based on visibility to users.

    With an empty ``lookup_value``, all repositories are listed, ordered by
    ``RepositorySearchScore.score`` descending. Otherwise repositories whose name (and, by default,
    description) match are returned, ordered by a computed ``score`` column. When descriptions are
    searched, name matches get 100x their stored score, which weights description-only matches
    lower. Repositories marked for deletion are always excluded. ``ids_only`` selects only
    ``Repository.id`` and skips the join to ``Namespace``.
    """
    select_fields = [Repository.id] if ids_only else [Repository, Namespace]

    if lookup_value:
        if search_fields is None:
            search_fields = {SEARCH_FIELDS.description.name, SEARCH_FIELDS.name.name}

        # Always search at least on name.
        clause = Repository.name.match(lookup_value)
        computed_score = RepositorySearchScore.score.alias("score")

        if SEARCH_FIELDS.description.name in search_fields:
            clause = Repository.description.match(lookup_value) | clause
            name_boost = (Repository.name.match(lookup_value), 100 * RepositorySearchScore.score)
            computed_score = Case(None, [name_boost], RepositorySearchScore.score).alias("score")

        select_fields.append(computed_score)
        query = (
            Repository.select(*select_fields)
            .join(RepositorySearchScore)
            .where(clause)
            .where(Repository.state != RepositoryState.MARKED_FOR_DELETION)
            .order_by(SQL("score").desc(), RepositorySearchScore.id)
        )
    else:
        # Generic listing: every repository, sorted by its RepositorySearchScore.
        query = (
            Repository.select(*select_fields)
            .join(RepositorySearchScore)
            .where(Repository.state != RepositoryState.MARKED_FOR_DELETION)
            .order_by(RepositorySearchScore.score.desc(), RepositorySearchScore.id)
        )

    if repo_kind is not None:
        query = query.where(Repository.kind == Repository.kind.get_id(repo_kind))

    if not include_private:
        query = query.where(Repository.visibility == _basequery.get_public_repo_visibility())

    if ids_only:
        return query
    return query.switch(Repository).join(
        Namespace, on=(Namespace.id == Repository.namespace_user)
    )
Example #24
Source File: repository.py (from quay, Apache License 2.0), 4 votes
def _filter_repositories_visible_to_user(unfiltered_query, filter_user_id, limit, repo_kind):
    """
    Lazily yield repositories from ``unfiltered_query`` that ``filter_user_id`` can see.

    The unfiltered query is read in pages of ``limit * 2`` rows, at most 10 pages. The ids in each
    page that were not seen before are filtered through ``_basequery.filter_to_repos_for_user``.
    The visible repositories are yielded in the order their ids appeared in the page.

    Iteration stops early when a page yields no new ids or is shorter than the chunk size.
    """
    encountered = set()
    chunk_count = limit * 2

    for iteration_count in range(10):  # Just to be safe
        # Find the next chunk's worth of repository IDs, paginated by the chunk size.
        unfiltered_page = iteration_count + 1
        found_ids = [r.id for r in unfiltered_query.paginate(unfiltered_page, chunk_count)]

        # Skip results we have already seen: pagination is not necessarily stable in SQL
        # databases, so a page can repeat earlier rows.
        new_unfiltered_ids = set(found_ids) - encountered
        if not new_unfiltered_ids:
            break

        encountered.update(new_unfiltered_ids)

        # Filter the repositories found to only those visible to the current user.
        query = (
            Repository.select(Repository, Namespace)
            .distinct()
            .join(Namespace, on=(Namespace.id == Repository.namespace_user))
            .switch(Repository)
            .join(RepositoryPermission)
            .where(Repository.id << list(new_unfiltered_ids))
        )

        filtered = _basequery.filter_to_repos_for_user(query, filter_user_id, repo_kind=repo_kind)

        # Sort by each id's first position in found_ids. A dict lookup replaces the O(n)
        # list.index call per element; setdefault keeps the first occurrence, as .index did.
        position = {}
        for index, repo_id in enumerate(found_ids):
            position.setdefault(repo_id, index)
        all_filtered_repos = sorted(filtered, key=lambda repo: position[repo.id])

        # Yield the repositories in sorted order.
        for filtered_repo in all_filtered_repos:
            yield filtered_repo

        # If the number of found IDs is less than the chunk count, then we're done.
        if len(found_ids) < chunk_count:
            break