Python asyncpg.Connection() Examples

The following are 30 code examples of asyncpg.Connection(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module asyncpg, or try the search function.
Example #1
Source File: database_migration.py    From NabBot with Apache License 2.0 6 votes vote down vote up
async def import_roles(conn: asyncpg.Connection, c: sqlite3.Cursor):
    """Copy auto roles and joinable roles from the legacy SQLite database into PostgreSQL.

    :param conn: Connection to the PostgreSQL database.
    :param c: Cursor over the legacy SQLite database.
    """
    log.info("Importing roles...")
    log.debug("Gathering auto roles from sqlite...")
    c.execute("SELECT server_id, role_id, guild FROM auto_roles")
    # Unpacking each row enforces the expected three-column shape.
    auto_roles = [(server_id, role_id, guild) for server_id, role_id, guild in c.fetchall()]
    log.debug(f"Collected {len(auto_roles):,} records from old database.")
    log.info("Copying records to auto roles table")
    # The legacy ``guild`` column is written into the ``rule`` column of ``role_auto``.
    res = await conn.copy_records_to_table("role_auto", records=auto_roles, columns=["server_id", "role_id", "rule"])
    log.info(f"Copied {get_affected_count(res):,} records successfully.")

    log.debug("Gathering joinable roles from sqlite...")
    c.execute("SELECT server_id, role_id FROM joinable_roles")
    joinable_roles = [(server_id, role_id) for server_id, role_id in c.fetchall()]
    log.debug(f"Collected {len(joinable_roles):,} records from old database.")
    log.info("Copying records to joinable roles table")
    res = await conn.copy_records_to_table("role_joinable", records=joinable_roles, columns=["server_id", "role_id"])
    log.info(f"Copied {get_affected_count(res):,} records successfully.")
    log.info("Finished importing roles.")
Example #2
Source File: database.py    From NabBot with Apache License 2.0 6 votes vote down vote up
async def get_latest(cls, conn: PoolConn, *, minimum_level=0, user_id=0, worlds: Union[List[str], str] = None):
        """Gets an asynchronous generator of the character's level ups.

        :param conn: Connection to the database.
        :param minimum_level: The minimum level to show.
        :param user_id: The id of a user to only show level ups of characters they own. ``0`` disables this filter.
        :param worlds: A world or list of worlds to only show level ups of characters in them.
            An empty value disables this filter.
        :return: An asynchronous generator containing the levels, newest first.
        """
        # Normalize ``worlds`` to a list; the query treats an empty array as "any world".
        if isinstance(worlds, str):
            worlds = [worlds]
        if not worlds:
            worlds = []
        # asyncpg cursors may only be iterated inside a transaction.
        async with conn.transaction():
            async for row in conn.cursor("""
                    SELECT l.*, (json_agg(c)->>0)::jsonb as char FROM character_levelup l
                    LEFT JOIN "character" c ON c.id = l.character_id
                    WHERE ($1::bigint = 0 OR c.user_id = $1) AND (cardinality($2::text[]) = 0 OR c.world = any($2))
                    AND l.level >= $3
                    GROUP BY l.id
                    ORDER BY date DESC""", user_id, worlds, minimum_level):
                yield cls(**row)
Example #3
Source File: database.py    From NabBot with Apache License 2.0 6 votes vote down vote up
async def update_world(self, conn: PoolConn, world: str, update_self=True) -> bool:
        """Updates the world of the character on the database.

        :param conn: Connection to the database.
        :param world: The new world to set.
        :param update_self: Whether to also update the object or not.
        :return: Whether the world was updated in the database or not.
        """
        result = await self.update_field_by_id(conn, self.id, "world", world)
        updated = result is not None
        # Use one success criterion for both the in-memory update and the return value,
        # so the object never disagrees with what the caller is told.
        if updated and update_self:
            self.world = world
        return updated

    # endregion

    # region Class methods 
Example #4
Source File: tracking.py    From NabBot with Apache License 2.0 6 votes vote down vote up
async def get_by_name(cls, conn: PoolConn, channel_id: int, name: str, is_guild: bool) -> \
            Optional['WatchlistEntry']:
        """Gets an entry by its name.

        The stored name is compared in lower case against ``name`` lower-cased and stripped
        of surrounding whitespace.

        :param conn: Connection to the database.
        :param channel_id: The id of the channel.
        :param name: Name of the entry.
        :param is_guild: Whether the entry is a guild or a character.
        :return: The entry if found, otherwise ``None``.
        """
        row = await conn.fetchrow("SELECT * FROM watchlist_entry "
                                  "WHERE channel_id = $1 AND lower(name) = $2 AND is_guild = $3",
                                  channel_id, name.lower().strip(), is_guild)
        if row is None:
            return None
        return cls(**row)
Example #5
Source File: tracking.py    From NabBot with Apache License 2.0 6 votes vote down vote up
async def insert(cls, conn: PoolConn, channel_id: int, name: str, is_guild: bool, user_id: int, reason=None)\
            -> Optional['WatchlistEntry']:
        """Inserts a watchlist entry into the database.

        Surrounding whitespace is stripped from ``name`` before storing it. Lookups and deletions
        of entries compare against ``name.lower().strip()``, so an unstripped name could never
        be matched again.

        :param conn: Connection to the database.
        :param channel_id: The id of the watchlist's channel.
        :param name: Name of the entry.
        :param is_guild: Whether the entry is a guild or a character.
        :param user_id: The id of the user that added the entry.
        :param reason: The reason for the entry.
        :return: The inserted entry, or ``None`` if no row was returned.
        """
        name = name.strip()
        row = await conn.fetchrow("INSERT INTO watchlist_entry(channel_id, name, is_guild, reason, user_id) "
                                  "VALUES($1, $2, $3, $4, $5) RETURNING *", channel_id, name, is_guild, reason, user_id)
        if row is None:
            return None
        return cls(**row)

# endregion 
Example #6
Source File: connection.py    From asyncpg with Apache License 2.0 6 votes vote down vote up
async def add_listener(self, channel, callback):
        """Add a listener for Postgres notifications.

        The first listener registered for a channel issues ``LISTEN`` on it.
        Later listeners for the same channel are only added to the local set.

        :param str channel: Channel to listen on.

        :param callable callback:
            A callable receiving the following arguments:
            **connection**: a Connection the callback is registered with;
            **pid**: PID of the Postgres server that sent the notification;
            **channel**: name of the channel the notification was sent to;
            **payload**: the payload.
        """
        self._check_open()
        if channel not in self._listeners:
            # The channel name is quoted as an identifier before being interpolated into SQL.
            await self.fetch('LISTEN {}'.format(utils._quote_ident(channel)))
            self._listeners[channel] = set()
        self._listeners[channel].add(callback)
Example #7
Source File: connection.py    From asyncpg with Apache License 2.0 6 votes vote down vote up
def add_log_listener(self, callback):
        """Register *callback* to receive Postgres log messages.

        The callback fires when an asynchronous NoticeResponse arrives on
        this connection. Possible message types are WARNING, NOTICE, DEBUG,
        INFO, or LOG.

        :param callable callback:
            A callable receiving the following arguments:
            **connection**: a Connection the callback is registered with;
            **message**: the `exceptions.PostgresLogMessage` message.

        .. versionadded:: 0.12.0
        """
        closed = self.is_closed()
        if not closed:
            self._log_listeners.add(callback)
            return
        raise exceptions.InterfaceError('connection is closed')
Example #8
Source File: __init__.py    From opentelemetry-python with Apache License 2.0 6 votes vote down vote up
def _instrument(self, **kwargs):
        """Install tracing wrappers on asyncpg's query methods.

        Stores a tracer on the ``asyncpg`` module under the ``_APPLIED``
        attribute name, then wraps each listed ``Connection`` method with
        ``_do_execute``.

        :param kwargs: May provide ``tracer_provider``. When it is absent, the
            global provider from ``trace.get_tracer_provider()`` is used.
        """
        provider = kwargs.get("tracer_provider", trace.get_tracer_provider())
        tracer = provider.get_tracer("asyncpg", __version__)
        setattr(asyncpg, _APPLIED, tracer)

        wrapped_methods = (
            "Connection.execute",
            "Connection.executemany",
            "Connection.fetch",
            "Connection.fetchval",
            "Connection.fetchrow",
        )
        for method_name in wrapped_methods:
            wrapt.wrap_function_wrapper("asyncpg.connection", method_name, _do_execute)
Example #9
Source File: compiler.py    From edgedb with Apache License 2.0 6 votes vote down vote up
async def _load_reflection_cache(
        self,
        connection: asyncpg.Connection,
    ) -> immutables.Map:
        """Load cached reflection metadata from the backend database.

        :param connection: Connection to the backend database.
        :return: An immutable mapping of each ``eql_hash`` to a tuple of its
            ``argnames``. The previous ``FrozenSet[str]`` annotation did not
            match this value.
        """
        data = await connection.fetch('''
            SELECT
                eql_hash,
                argnames
            FROM
                ROWS FROM(edgedb._get_cached_reflection())
                    AS t(eql_hash text, argnames text[])
        ''')

        return immutables.Map({
            r['eql_hash']: tuple(r['argnames']) for r in data
        })
Example #10
Source File: compiler.py    From edgedb with Apache License 2.0 6 votes vote down vote up
async def ensure_initialized(self, con: asyncpg.Connection) -> None:
        """Lazily load schema metadata from the database.

        Each cached attribute is fetched only while it is still ``None``, so
        repeated calls are cheap once everything has been loaded.

        :param con: Connection used to load the cached data.
        """
        if self._std_schema is None:
            self._std_schema = await load_cached_schema(con, 'stdschema')

        if self._refl_schema is None:
            self._refl_schema = await load_cached_schema(con, 'reflschema')

        if self._schema_class_layout is None:
            self._schema_class_layout = await load_schema_class_layout(con)

        if self._intro_query is None:
            self._intro_query = await load_schema_intro_query(con)

        if self._config_spec is None:
            # The config spec is derived from the std schema loaded above.
            self._config_spec = config.load_spec_from_schema(
                self._std_schema)
            config.set_settings(self._config_spec)
Example #11
Source File: pgutils.py    From openmaptiles-tools with MIT License 5 votes vote down vote up
def __init__(self, conn: Connection, delay_printing=False) -> None:
        # Relay every server log message on *conn* to ``self.on_warning``.
        # The connection argument that the listener also receives is ignored.
        def _forward(_connection, message):
            return self.on_warning(message)

        self.messages = []
        self.delay_printing = delay_printing
        conn.add_log_listener(_forward)
Example #12
Source File: test_pool.py    From asyncpg with Apache License 2.0 5 votes vote down vote up
async def test_pool_config_persistence(self):
        """A pool applies ``connection_class`` and connect kwargs to every connection it opens.

        With ``max_queries=1`` the pool replaces connections after use. The ``N``
        acquisitions are therefore expected to observe ``N`` distinct connection objects.
        """
        N = 100
        cons = set()

        class MyConnection(asyncpg.Connection):
            async def foo(self):
                return 42

            async def fetchval(self, query):
                res = await super().fetchval(query)
                return res + 1

        async def test(pool):
            async with pool.acquire() as con:
                self.assertEqual(await con.fetchval('SELECT 1'), 2)
                self.assertEqual(await con.foo(), 42)
                self.assertTrue(isinstance(con, MyConnection))
                self.assertEqual(con._con._config.statement_cache_size, 3)
                cons.add(con)

        async with self.create_pool(
                database='postgres', min_size=10, max_size=10,
                max_queries=1, connection_class=MyConnection,
                statement_cache_size=3) as pool:

            await asyncio.gather(*[test(pool) for _ in range(N)])

        self.assertEqual(len(cons), N)
Example #13
Source File: tracking.py    From NabBot with Apache License 2.0 5 votes vote down vote up
async def get_by_channel_id(cls, conn: PoolConn, channel_id: int) -> Optional['Watchlist']:
        """Gets a watchlist corresponding to the channel id.

        :param conn: Connection to the database.
        :param channel_id: The id of the channel.
        :return: The found watchlist, or ``None`` if there is none.
        """
        row = await conn.fetchrow("SELECT * FROM watchlist WHERE channel_id = $1", channel_id)
        if row is None:
            return None
        return cls(**row)
Example #14
Source File: tracking.py    From NabBot with Apache License 2.0 5 votes vote down vote up
async def remove(self, conn: PoolConn):
        """Removes this watchlist entry from the database.

        Delegates to ``delete`` using this entry's channel id, name and guild flag.

        :param conn: Connection to the database.
        """
        await self.delete(conn, self.channel_id, self.name, self.is_guild)
Example #15
Source File: connection.py    From asyncpg with Apache License 2.0 5 votes vote down vote up
def _cleanup_stmts(self):
        # Called whenever we create a new prepared statement in
        # `Connection._get_statement()` and `_stmts_to_close` is
        # not empty.
        to_close = self._stmts_to_close
        self._stmts_to_close = set()
        for stmt in to_close:
            # It is imperative that statements are cleaned properly,
            # so we ignore the timeout.
            await self._protocol.close_statement(stmt, protocol.NO_TIMEOUT) 
Example #16
Source File: connection.py    From asyncpg with Apache License 2.0 5 votes vote down vote up
async def executemany(self, command: str, args, *, timeout: float=None):
        """Execute an SQL *command* for each sequence of arguments in *args*.

        Example:

        .. code-block:: pycon

            >>> await con.executemany('''
            ...     INSERT INTO mytab (a, b, c) VALUES ($1, $2, $3);
            ... ''', [(1, 2, 3), (4, 5, 6)])

        :param command: Command to execute.
        :param args: An iterable containing sequences of arguments.
        :param float timeout: Optional timeout value in seconds.
        :return None: This method discards the results of the operations.

        .. note::

           When inserting a large number of rows,
           use :meth:`Connection.copy_records_to_table()` instead,
           it is much more efficient for this purpose.

        .. versionadded:: 0.7.0

        .. versionchanged:: 0.11.0
           `timeout` became a keyword-only parameter.
        """
        self._check_open()
        return await self._executemany(command, args, timeout)
Example #17
Source File: connection.py    From asyncpg with Apache License 2.0 5 votes vote down vote up
def _cleanup(self):
        # Free the resources associated with this connection.
        # This must be called when a connection is terminated.

        if self._proxy is not None:
            # Connection is a member of a pool, so let the pool
            # know that this connection is dead.
            self._proxy._holder._release_on_close()

        self._mark_stmts_as_closed()
        self._listeners.clear()
        self._log_listeners.clear()
        self._clean_tasks() 
Example #18
Source File: connection.py    From asyncpg with Apache License 2.0 5 votes vote down vote up
def _maybe_gc_stmt(self, stmt):
        if stmt.refs == 0 and not self._stmt_cache.has(stmt.query):
            # If low-level `stmt` isn't referenced from any high-level
            # `PreparedStatement` object and is not in the `_stmt_cache`:
            #
            #  * mark it as closed, which will make it non-usable
            #    for any `PreparedStatement` or for methods like
            #    `Connection.fetch()`.
            #
            # * schedule it to be formally closed on the server.
            stmt.mark_closed()
            self._stmts_to_close.add(stmt) 
Example #19
Source File: connection.py    From asyncpg with Apache License 2.0 5 votes vote down vote up
def is_in_transaction(self):
        """Report whether this connection is currently inside a transaction.

        The answer is delegated to the underlying protocol object.

        :return bool: True if inside transaction, False otherwise.

        .. versionadded:: 0.16.0
        """
        proto = self._protocol
        return proto.is_in_transaction()
Example #20
Source File: connection.py    From asyncpg with Apache License 2.0 5 votes vote down vote up
def _unwrap(self):
        if self._proxy is None:
            con_ref = self
        else:
            # `_proxy` is not None when the connection is a member
            # of a connection pool.  Which means that the user is working
            # with a `PoolConnectionProxy` instance, and expects to see it
            # (and not the actual Connection) in their event callbacks.
            con_ref = self._proxy
        return con_ref 
Example #21
Source File: tracking.py    From NabBot with Apache License 2.0 5 votes vote down vote up
async def insert(cls, conn: PoolConn, server_id: int, channel_id: int, user_id: int) -> 'Watchlist':
        """Adds a new watchlist to the database.

        :param conn: Connection to the database.
        :param server_id: The discord guild's id.
        :param channel_id: The channel's id.
        :param user_id: The user that created the watchlist.
        :return: The created watchlist, built from the row returned by ``RETURNING *``.
        """
        row = await conn.fetchrow("INSERT INTO watchlist(server_id, channel_id, user_id) VALUES($1,$2,$3) RETURNING *",
                                  server_id, channel_id, user_id)
        return cls(**row)
Example #22
Source File: tracking.py    From NabBot with Apache License 2.0 5 votes vote down vote up
async def update_show_count(self, conn: PoolConn, show_count: bool):
        """Updates the show_count property in the database and on this object.

        If the property is True, the number of online entries will be shown in the channel's name.

        :param conn: Connection to the database.
        :param show_count: The property's new value.
        """
        await conn.execute("UPDATE watchlist SET show_count = $1 WHERE channel_id = $2", show_count, self.channel_id)
        self.show_count = show_count
Example #23
Source File: tracking.py    From NabBot with Apache License 2.0 5 votes vote down vote up
async def update_message_id(self, conn: PoolConn, message_id: int):
        """Updates the message id in the database and on this object.

        :param conn: Connection to the database.
        :param message_id: The new message id.
        """
        await conn.execute("UPDATE watchlist SET message_id = $1 WHERE channel_id = $2", message_id, self.channel_id)
        self.message_id = message_id
Example #24
Source File: tracking.py    From NabBot with Apache License 2.0 5 votes vote down vote up
async def get_entries(self, conn: PoolConn) -> List['WatchlistEntry']:
        """Gets all entries in this watchlist.

        :param conn: Connection to the database.
        :return: List of entries if any.
        """
        return await WatchlistEntry.get_entries_by_channel(conn, self.channel_id)
Example #25
Source File: database_migration.py    From NabBot with Apache License 2.0 5 votes vote down vote up
async def import_ignored_channels(conn: asyncpg.Connection, c: sqlite3.Cursor):
    """Copy ignored channels from the legacy SQLite database into PostgreSQL.

    :param conn: Connection to the PostgreSQL database.
    :param c: Cursor over the legacy SQLite database.
    """
    log.info("Importing ignored channels...")
    log.debug("Gathering ignored channels from sqlite...")
    c.execute("SELECT server_id, channel_id FROM ignored_channels")
    # Unpacking each row enforces the expected two-column shape.
    channels = [(server_id, channel_id) for server_id, channel_id in c.fetchall()]
    log.debug(f"Collected {len(channels):,} records from old database.")
    log.info("Copying records to ignored channels table")
    # Legacy channel ids are stored in the generic ``entry_id`` column of ``ignored_entry``.
    res = await conn.copy_records_to_table("ignored_entry", records=channels, columns=["server_id", "entry_id"])
    log.info(f"Copied {get_affected_count(res):,} records successfully.")
    log.info("Finished importing channels.")
Example #26
Source File: database.py    From NabBot with Apache License 2.0 5 votes vote down vote up
async def get_recent_timeline(conn: PoolConn, *, minimum_level=0, user_id=0, worlds: Union[List[str], str] = None):
    """Gets an asynchronous generator of recent deaths and level ups.

    Deaths and level ups are merged into one stream ordered by date, newest first. Rows tagged
    ``'l'`` are yielded as ``DbLevelUp``; all others are yielded as ``DbDeath``.

    :param conn: Connection to the database.
    :param minimum_level: The minimum level to show.
    :param user_id: The id of a user to only show entries of characters they own. ``0`` disables this filter.
    :param worlds: A world or list of worlds to only show entries of characters in them.
        An empty value disables this filter.
    :return: An asynchronous generator containing the entries.
    """
    # Normalize ``worlds`` to a list; the query treats an empty array as "any world".
    if isinstance(worlds, str):
        worlds = [worlds]
    if not worlds:
        worlds = []
    # asyncpg cursors may only be iterated inside a transaction.
    async with conn.transaction():
        async for row in conn.cursor(f"""
                (
                    SELECT d.*, (json_agg(c)->>0)::jsonb as char, json_agg(k)::jsonb as killers, 'd' AS type
                    FROM character_death d
                    LEFT JOIN {DbKiller.table} k ON k.death_id = d.id
                    LEFT JOIN "character" c ON c.id = d.character_id
                    WHERE ($1::bigint = 0 OR c.user_id = $1) AND
                    (cardinality($2::text[]) = 0 OR c.world = any($2)) AND d.level >= $3
                    GROUP BY d.id
                )
                UNION
                (
                    SELECT l.*, (json_agg(c)->>0)::jsonb as char, NULL, 'l' AS type
                    FROM character_levelup l
                    LEFT JOIN "character" c ON c.id = l.character_id
                    WHERE ($1::bigint = 0 OR c.user_id = $1) AND
                    (cardinality($2::text[]) = 0 OR c.world = any($2)) AND l.level >= $3
                    GROUP BY l.id
                )
                ORDER by date DESC
                """, user_id, worlds, minimum_level):
            if row["type"] == "l":
                yield DbLevelUp(**row)
            else:
                yield DbDeath(**row)
Example #27
Source File: database_migration.py    From NabBot with Apache License 2.0 5 votes vote down vote up
async def import_server_properties(conn: asyncpg.Connection, c: sqlite3.Cursor):
    """Copy server properties, prefixes and timezones from the legacy SQLite database.

    Rows of the old ``server_properties`` table are routed as follows:

    - ``prefixes`` and ``times`` (both JSON-encoded) go to their own tables.
    - ``watched_message`` and ``watched_channel`` are dropped.
    - Every other key is copied to ``server_property`` after type conversion.

    :param conn: Connection to the PostgreSQL database.
    :param c: Cursor over the legacy SQLite database.
    """
    properties = []
    prefixes = []
    times = []
    log.info("Importing server properties...")
    log.debug("Gathering server property records from sqlite...")
    c.execute("SELECT server_id, name, value FROM server_properties")
    rows = c.fetchall()
    for server_id, key, value in rows:
        server_id = int(server_id)
        if key == "prefixes":
            prefixes.append((server_id, json.loads(value)))
            continue
        if key == "times":
            for entry in json.loads(value):
                times.append((server_id, entry["timezone"], entry["name"]))
            continue
        if key in ["events_channel", "levels_channel", "news_channel", "welcome_channel", "ask_channel",
                   "announce_channel", "announce_level"]:
            value = int(value)
        elif key in ["watched_message", "watched_channel"]:
            continue
        elif key == "commandsonly":
            # SQLite may return this flag as text (e.g. "0"/"1"). bool() of any non-empty
            # string is True, so textual values are parsed explicitly.
            if isinstance(value, str):
                value = value.strip().lower() not in ("", "0", "false")
            else:
                value = bool(value)
        properties.append((server_id, key, value))
    log.debug(f"Collected {len(properties):,} properties, {len(times):,} timezones and {len(prefixes):,} prefixes"
              f" from old database.")
    log.info("Copying records to server property table")
    res = await conn.copy_records_to_table("server_property", records=properties, columns=["server_id", "key", "value"])
    log.info(f"Copied {get_affected_count(res):,} records successfully.")

    log.info("Copying records to server prefixes table")
    res = await conn.copy_records_to_table("server_prefixes", records=prefixes, columns=["server_id", "prefixes"])
    log.info(f"Copied {get_affected_count(res):,} records successfully.")

    log.info("Copying records to server timezone table")
    res = await conn.copy_records_to_table("server_timezone", records=times, columns=["server_id", "zone", "name"])
    log.info(f"Copied {get_affected_count(res):,} records successfully.")
    log.info("Finished importing server properties.")
Example #28
Source File: database_migration.py    From NabBot with Apache License 2.0 5 votes vote down vote up
async def set_version(con: asyncpg.connection.Connection, version):
    """Sets the database's version.

    Upserts the ``db_version`` key of ``global_property``. An existing value is overwritten.

    :param con: Connection to the database.
    :param version: The version value to store.
    """
    await con.execute("""
        INSERT INTO global_property (key, value) VALUES ('db_version',$1)
        ON CONFLICT (key)
        DO
         UPDATE
           SET value = EXCLUDED.value;
    """, version)
Example #29
Source File: database_migration.py    From NabBot with Apache License 2.0 5 votes vote down vote up
async def create_database(con: asyncpg.connection.Connection):
    """Creates NabBot's tables, functions and triggers, then records the latest schema version.

    :param con: Connection to the database.
    """
    log.info("Creating tables...")
    for create_query in tables:
        await con.execute(create_query)
    log.info("Creating functions...")
    for f in functions:
        await con.execute(f)
    log.info("Creating triggers...")
    for trigger in triggers:
        await con.execute(trigger)
    # Log the version actually being written instead of a hard-coded number.
    log.info(f"Setting version to {LATEST_VERSION}...")
    await set_version(con, LATEST_VERSION)
Example #30
Source File: tracking.py    From NabBot with Apache License 2.0 5 votes vote down vote up
async def delete(cls, conn: PoolConn, channel_id: int, name: str, is_guild: bool):
        """Deletes a watchlist entry from the database.

        The stored name is compared in lower case against ``name`` lower-cased and stripped
        of surrounding whitespace.

        :param conn: Connection to the database.
        :param channel_id: The id of the watchlist's channel.
        :param name: The name of the entry.
        :param is_guild: Whether the entry is a guild or a character.
        """
        await conn.execute("DELETE FROM watchlist_entry WHERE channel_id = $1 AND lower(name) = $2 AND is_guild = $3",
                           channel_id, name.lower().strip(), is_guild)