Python airflow.settings.Session() Examples

The following are 30 code examples of airflow.settings.Session(), drawn from the Apache Airflow project's source. Each example notes its source file and license.
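Most of the examples below follow the same basic pattern: open a session via settings.Session(), add or merge ORM objects, commit, and close. As a minimal standalone sketch of that pattern (assuming a configured Airflow installation; the pool name is illustrative):

from airflow import settings
from airflow.models import Pool

session = settings.Session()
try:
    # Persist a new object, mirroring what several tests below do.
    session.add(Pool(pool='example_pool', slots=4))
    session.commit()
except Exception:
    session.rollback()
    raise
finally:
    session.close()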
Example #1
Source File: test_pool.py    From airflow with Apache License 2.0
def test_infinite_slots(self):
        pool = Pool(pool='test_pool', slots=-1)
        dag = DAG(
            dag_id='test_infinite_slots',
            start_date=DEFAULT_DATE, )
        op1 = DummyOperator(task_id='dummy1', dag=dag, pool='test_pool')
        op2 = DummyOperator(task_id='dummy2', dag=dag, pool='test_pool')
        ti1 = TI(task=op1, execution_date=DEFAULT_DATE)
        ti2 = TI(task=op2, execution_date=DEFAULT_DATE)
        ti1.state = State.RUNNING
        ti2.state = State.QUEUED

        # settings.Session is a scoped_session, so it can be used directly
        # without being called; settings.Session() would behave the same way.
        session = settings.Session
        session.add(pool)
        session.add(ti1)
        session.add(ti2)
        session.commit()
        session.close()

        self.assertEqual(float('inf'), pool.open_slots())  # pylint: disable=no-value-for-parameter
        self.assertEqual(1, pool.running_slots())  # pylint: disable=no-value-for-parameter
        self.assertEqual(1, pool.queued_slots())  # pylint: disable=no-value-for-parameter
        self.assertEqual(2, pool.occupied_slots())  # pylint: disable=no-value-for-parameter 
Example #2
Source File: test_taskinstance.py    From airflow with Apache License 2.0
def test_task_stats(self, stats_mock):
        dag = DAG('test_task_start_end_stats', start_date=DEFAULT_DATE,
                  end_date=DEFAULT_DATE + datetime.timedelta(days=10))
        op = DummyOperator(task_id='dummy_op', dag=dag)
        dag.create_dagrun(
            run_id='manual__' + DEFAULT_DATE.isoformat(),
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            state=State.RUNNING,
            external_trigger=False)
        ti = TI(task=op, execution_date=DEFAULT_DATE)
        ti.state = State.RUNNING
        session = settings.Session()
        session.merge(ti)
        session.commit()
        ti._run_raw_task()
        ti.refresh_from_db()
        stats_mock.assert_called_with('ti.finish.{}.{}.{}'.format(dag.dag_id, op.task_id, ti.state))
        self.assertIn(call('ti.start.{}.{}'.format(dag.dag_id, op.task_id)), stats_mock.mock_calls)
        self.assertEqual(stats_mock.call_count, 5) 
Example #3
Source File: test_trigger_rule_dep.py    From airflow with Apache License 2.0
def test_all_done_tr_success(self):
        """
        All-done trigger rule success
        """
        ti = self._get_task_instance(TriggerRule.ALL_DONE,
                                     upstream_task_ids=["FakeTaskID",
                                                        "OtherFakeTaskID"])
        dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
            ti=ti,
            successes=2,
            skipped=0,
            failed=0,
            upstream_failed=0,
            done=2,
            flag_upstream_failed=False,
            session="Fake Session"))
        self.assertEqual(len(dep_statuses), 0) 
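Note: this test and the trigger-rule tests that follow pass the literal string "Fake Session" in place of a real session. With flag_upstream_failed=False the rule evaluation apparently never touches the database, so any placeholder suffices.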
Example #4
Source File: gcp_authenticator.py    From airflow with Apache License 2.0
def set_dictionary_in_airflow_connection(self):
        """
        Set dictionary in 'google_cloud_default' connection to contain content
        of the json service account file.
        :return: None
        """
        session = settings.Session()
        try:
            conn = session.query(Connection).filter(Connection.conn_id == 'google_cloud_default')[0]
            extras = conn.extra_dejson
            with open(self.full_key_path, "r") as path_file:
                content = json.load(path_file)
            extras[KEYFILE_DICT_EXTRA] = json.dumps(content)
            if extras.get(KEYPATH_EXTRA):
                del extras[KEYPATH_EXTRA]
            extras[SCOPE_EXTRA] = 'https://www.googleapis.com/auth/cloud-platform'
            extras[PROJECT_EXTRA] = self.project_extra
            conn.extra = json.dumps(extras)
            session.commit()
        except BaseException as ex:
            self.log.error('Airflow DB Session error: %s', str(ex))
            session.rollback()
            raise
        finally:
            session.close() 
Example #5
Source File: test_trigger_rule_dep.py    From airflow with Apache License 2.0
def test_all_failed_tr_success(self):
        """
        All-failed trigger rule success
        """
        ti = self._get_task_instance(TriggerRule.ALL_FAILED,
                                     upstream_task_ids=["FakeTaskID",
                                                        "OtherFakeTaskID"])
        dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
            ti=ti,
            successes=0,
            skipped=0,
            failed=2,
            upstream_failed=0,
            done=2,
            flag_upstream_failed=False,
            session="Fake Session"))
        self.assertEqual(len(dep_statuses), 0) 
Example #6
Source File: test_trigger_rule_dep.py    From airflow with Apache License 2.0
def test_all_done_tr_failure(self):
        """
        All-done trigger rule failure
        """
        ti = self._get_task_instance(TriggerRule.ALL_DONE,
                                     upstream_task_ids=["FakeTaskID",
                                                        "OtherFakeTaskID"])
        dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
            ti=ti,
            successes=1,
            skipped=0,
            failed=0,
            upstream_failed=0,
            done=1,
            flag_upstream_failed=False,
            session="Fake Session"))
        self.assertEqual(len(dep_statuses), 1)
        self.assertFalse(dep_statuses[0].passed) 
Example #7
Source File: test_trigger_rule_dep.py    From airflow with Apache License 2.0
def test_unknown_tr(self):
        """
        Unknown trigger rules should cause this dep to fail
        """
        ti = self._get_task_instance()
        ti.task.trigger_rule = "Unknown Trigger Rule"
        dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
            ti=ti,
            successes=1,
            skipped=0,
            failed=0,
            upstream_failed=0,
            done=1,
            flag_upstream_failed=False,
            session="Fake Session"))

        self.assertEqual(len(dep_statuses), 1)
        self.assertFalse(dep_statuses[0].passed) 
Example #8
Source File: test_dag_command.py    From airflow with Apache License 2.0
def test_delete_dag(self):
        DM = DagModel
        key = "my_dag_id"
        session = settings.Session()
        session.add(DM(dag_id=key))
        session.commit()
        dag_command.dag_delete(self.parser.parse_args([
            'dags', 'delete', key, '--yes']))
        self.assertEqual(session.query(DM).filter_by(dag_id=key).count(), 0)
        self.assertRaises(
            AirflowException,
            dag_command.dag_delete,
            self.parser.parse_args([
                'dags', 'delete',
                'does_not_exist_dag',
                '--yes'])
        ) 
Example #9
Source File: test_trigger_rule_dep.py    From airflow with Apache License 2.0
def test_all_failed_tr_failure(self):
        """
        All-failed trigger rule failure
        """
        ti = self._get_task_instance(TriggerRule.ALL_FAILED,
                                     upstream_task_ids=["FakeTaskID",
                                                        "OtherFakeTaskID"])
        dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
            ti=ti,
            successes=2,
            skipped=0,
            failed=0,
            upstream_failed=0,
            done=2,
            flag_upstream_failed=False,
            session="Fake Session"))
        self.assertEqual(len(dep_statuses), 1)
        self.assertFalse(dep_statuses[0].passed) 
Example #10
Source File: test_trigger_rule_dep.py    From airflow with Apache License 2.0
def test_none_failed_or_skipped_tr_success(self):
        """
        None-failed-or-skipped trigger rule success
        """
        ti = self._get_task_instance(TriggerRule.NONE_FAILED_OR_SKIPPED,
                                     upstream_task_ids=["FakeTaskID",
                                                        "OtherFakeTaskID"])
        dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
            ti=ti,
            successes=1,
            skipped=1,
            failed=0,
            upstream_failed=0,
            done=2,
            flag_upstream_failed=False,
            session="Fake Session"))
        self.assertEqual(len(dep_statuses), 0) 
Example #11
Source File: test_trigger_rule_dep.py    From airflow with Apache License 2.0
def test_none_failed_tr_failure(self):
        """
        None-failed trigger rule failure
        """
        ti = self._get_task_instance(TriggerRule.NONE_FAILED,
                                     upstream_task_ids=["FakeTaskID",
                                                        "OtherFakeTaskID",
                                                        "FailedFakeTaskID"])
        dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
            ti=ti,
            successes=1,
            skipped=1,
            failed=1,
            upstream_failed=0,
            done=3,
            flag_upstream_failed=False,
            session="Fake Session"))
        self.assertEqual(len(dep_statuses), 1)
        self.assertFalse(dep_statuses[0].passed) 
Example #12
Source File: test_trigger_rule_dep.py    From airflow with Apache License 2.0
def test_none_failed_tr_success(self):
        """
        None-failed trigger rule success
        """
        ti = self._get_task_instance(TriggerRule.NONE_FAILED,
                                     upstream_task_ids=["FakeTaskID",
                                                        "OtherFakeTaskID"])
        dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
            ti=ti,
            successes=1,
            skipped=1,
            failed=0,
            upstream_failed=0,
            done=2,
            flag_upstream_failed=False,
            session="Fake Session"))
        self.assertEqual(len(dep_statuses), 0) 
Example #13
Source File: test_trigger_rule_dep.py    From airflow with Apache License 2.0
def test_all_success_tr_failure(self):
        """
        All-success trigger rule failure
        """
        ti = self._get_task_instance(TriggerRule.ALL_SUCCESS,
                                     upstream_task_ids=["FakeTaskID",
                                                        "OtherFakeTaskID"])
        dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
            ti=ti,
            successes=1,
            skipped=0,
            failed=1,
            upstream_failed=0,
            done=2,
            flag_upstream_failed=False,
            session="Fake Session"))
        self.assertEqual(len(dep_statuses), 1)
        self.assertFalse(dep_statuses[0].passed) 
Example #14
Source File: test_trigger_rule_dep.py    From airflow with Apache License 2.0
def test_one_failure_tr_failure(self):
        """
        One-failure trigger rule failure
        """
        ti = self._get_task_instance(TriggerRule.ONE_FAILED)
        dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
            ti=ti,
            successes=2,
            skipped=0,
            failed=0,
            upstream_failed=0,
            done=2,
            flag_upstream_failed=False,
            session="Fake Session"))
        self.assertEqual(len(dep_statuses), 1)
        self.assertFalse(dep_statuses[0].passed) 
Example #15
Source File: test_trigger_rule_dep.py    From airflow with Apache License 2.0
def test_all_success_tr_skip(self):
        """
        All-success trigger rule fails when some upstream tasks are skipped.
        """
        ti = self._get_task_instance(TriggerRule.ALL_SUCCESS,
                                     upstream_task_ids=["FakeTaskID",
                                                        "OtherFakeTaskID"])
        dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
            ti=ti,
            successes=1,
            skipped=1,
            failed=0,
            upstream_failed=0,
            done=2,
            flag_upstream_failed=False,
            session="Fake Session"))
        self.assertEqual(len(dep_statuses), 1)
        self.assertFalse(dep_statuses[0].passed) 
Example #16
Source File: test_taskinstance.py    From airflow with Apache License 2.0
def test_execute_callback(self):
        called = False

        def on_execute_callable(context):
            nonlocal called
            called = True
            self.assertEqual(
                context['dag_run'].dag_id,
                'test_execute_callback'
            )

        dag = DAG('test_execute_callback', start_date=DEFAULT_DATE,
                  end_date=DEFAULT_DATE + datetime.timedelta(days=10))
        task = DummyOperator(task_id='op', email='test@test.test',
                             on_execute_callback=on_execute_callable,
                             dag=dag)
        ti = TI(task=task, execution_date=datetime.datetime.now())
        ti.state = State.RUNNING
        session = settings.Session()
        session.merge(ti)
        session.commit()
        ti._run_raw_task()
        assert called
        ti.refresh_from_db()
        assert ti.state == State.SUCCESS 
Example #17
Source File: test_dagrun.py    From airflow with Apache License 2.0
def test_clear_task_instances_for_backfill_dagrun(self):
        now = timezone.utcnow()
        session = settings.Session()
        dag_id = 'test_clear_task_instances_for_backfill_dagrun'
        dag = DAG(dag_id=dag_id, start_date=now)
        self.create_dag_run(dag, execution_date=now, is_backfill=True)

        task0 = DummyOperator(task_id='backfill_task_0', owner='test', dag=dag)
        ti0 = TI(task=task0, execution_date=now)
        ti0.run()

        qry = session.query(TI).filter(
            TI.dag_id == dag.dag_id).all()
        clear_task_instances(qry, session)
        session.commit()
        ti0.refresh_from_db()
        dr0 = session.query(DagRun).filter(
            DagRun.dag_id == dag_id,
            DagRun.execution_date == now
        ).first()
        self.assertEqual(dr0.state, State.RUNNING) 
Example #18
Source File: test_taskinstance.py    From airflow with Apache License 2.0
def test_success_callback_no_race_condition(self):
        callback_wrapper = CallbackWrapper()
        dag = DAG('test_success_callback_no_race_condition', start_date=DEFAULT_DATE,
                  end_date=DEFAULT_DATE + datetime.timedelta(days=10))
        task = DummyOperator(task_id='op', email='test@test.test',
                             on_success_callback=callback_wrapper.success_handler, dag=dag)
        ti = TI(task=task, execution_date=datetime.datetime.now())
        ti.state = State.RUNNING
        session = settings.Session()
        session.merge(ti)
        session.commit()
        callback_wrapper.wrap_task_instance(ti)
        ti._run_raw_task()
        self.assertTrue(callback_wrapper.callback_ran)
        self.assertEqual(callback_wrapper.task_state_in_callback, State.RUNNING)
        ti.refresh_from_db()
        self.assertEqual(ti.state, State.SUCCESS) 
Example #19
Source File: test_dagrun.py    From airflow with Apache License 2.0
def test_dagrun_no_deadlock_with_shutdown(self):
        session = settings.Session()
        dag = DAG('test_dagrun_no_deadlock_with_shutdown',
                  start_date=DEFAULT_DATE)
        with dag:
            op1 = DummyOperator(task_id='upstream_task')
            op2 = DummyOperator(task_id='downstream_task')
            op2.set_upstream(op1)

        dr = dag.create_dagrun(run_id='test_dagrun_no_deadlock_with_shutdown',
                               state=State.RUNNING,
                               execution_date=DEFAULT_DATE,
                               start_date=DEFAULT_DATE)
        upstream_ti = dr.get_task_instance(task_id='upstream_task')
        upstream_ti.set_state(State.SHUTDOWN, session=session)

        dr.update_state()
        self.assertEqual(dr.state, State.RUNNING) 
Example #20
Source File: test_variable.py    From airflow with Apache License 2.0
def test_var_with_encryption_rotate_fernet_key(self, test_value):
        """
        Tests rotating encrypted variables.
        """
        key1 = Fernet.generate_key()
        key2 = Fernet.generate_key()

        with conf_vars({('core', 'fernet_key'): key1.decode()}):
            Variable.set('key', test_value)
            session = settings.Session()
            test_var = session.query(Variable).filter(Variable.key == 'key').one()
            self.assertTrue(test_var.is_encrypted)
            self.assertEqual(test_var.val, test_value)
            self.assertEqual(Fernet(key1).decrypt(test_var._val.encode()), test_value.encode())

        # Test decrypt of old value with new key
        with conf_vars({('core', 'fernet_key'): ','.join([key2.decode(), key1.decode()])}):
            crypto._fernet = None
            self.assertEqual(test_var.val, test_value)

            # Test decrypt of new value with new key
            test_var.rotate_fernet_key()
            self.assertTrue(test_var.is_encrypted)
            self.assertEqual(test_var.val, test_value)
            self.assertEqual(Fernet(key2).decrypt(test_var._val.encode()), test_value.encode()) 
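A note on the rotation above: Airflow accepts a comma-separated fernet_key and (via cryptography's MultiFernet) encrypts with the first key while trying every listed key for decryption. That is why the old value stays readable under the new configuration until rotate_fernet_key() re-encrypts it with key2.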
Example #21
Source File: test_sentry.py    From airflow with Apache License 2.0
def setUp(self):

        self.sentry = ConfiguredSentry()

        # Mock the Dag
        self.dag = Mock(dag_id=DAG_ID, params=[])
        self.dag.task_ids = [TASK_ID]

        # Mock the task
        self.task = Mock(dag=self.dag, dag_id=DAG_ID, task_id=TASK_ID, params=[], pool_slots=1)
        self.task.__class__.__name__ = OPERATOR

        self.ti = TaskInstance(self.task, execution_date=EXECUTION_DATE)
        self.ti.operator = OPERATOR
        self.ti.state = STATE

        self.dag.get_task_instances = MagicMock(return_value=[self.ti])

        self.session = Session() 
Example #22
Source File: test_dag.py    From airflow with Apache License 2.0
def test_dag_is_deactivated_upon_dagfile_deletion(self):
        dag_id = 'old_existing_dag'
        dag_fileloc = "/usr/local/airflow/dags/non_existing_path.py"
        dag = DAG(
            dag_id,
            is_paused_upon_creation=True,
        )
        dag.fileloc = dag_fileloc
        session = settings.Session()
        dag.sync_to_db(session=session)

        orm_dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).one()

        self.assertTrue(orm_dag.is_active)
        self.assertEqual(orm_dag.fileloc, dag_fileloc)

        DagModel.deactivate_deleted_dags(list_py_file_paths(settings.DAGS_FOLDER))

        orm_dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).one()
        self.assertFalse(orm_dag.is_active)

        # pylint: disable=no-member
        session.execute(DagModel.__table__.delete().where(DagModel.dag_id == dag_id))
        session.close() 
Example #23
Source File: test_trigger_rule_dep.py    From airflow with Apache License 2.0
def test_one_success_tr_failure(self):
        """
        One-success trigger rule failure
        """
        ti = self._get_task_instance(TriggerRule.ONE_SUCCESS)
        dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule(
            ti=ti,
            successes=0,
            skipped=2,
            failed=2,
            upstream_failed=2,
            done=2,
            flag_upstream_failed=False,
            session="Fake Session"))
        self.assertEqual(len(dep_statuses), 1)
        self.assertFalse(dep_statuses[0].passed) 
Example #24
Source File: test_dag.py    From airflow with Apache License 2.0
def test_sync_to_db_default_view(self, mock_now):
        dag = DAG(
            'dag',
            start_date=DEFAULT_DATE,
            default_view="graph",
        )
        with dag:
            DummyOperator(task_id='task', owner='owner1')
            SubDagOperator(
                task_id='subtask',
                owner='owner2',
                subdag=DAG(
                    'dag.subtask',
                    start_date=DEFAULT_DATE,
                )
            )
        now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC'))
        mock_now.return_value = now
        session = settings.Session()
        dag.sync_to_db(session=session)

        orm_dag = session.query(DagModel).filter(DagModel.dag_id == 'dag').one()
        self.assertIsNotNone(orm_dag.default_view)
        self.assertEqual(orm_dag.default_view, "graph")
        session.close() 
Example #25
Source File: test_pool.py    From airflow with Apache License 2.0
def test_default_pool_open_slots(self):
        set_default_pool_slots(5)
        self.assertEqual(5, Pool.get_default_pool().open_slots())

        dag = DAG(
            dag_id='test_default_pool_open_slots',
            start_date=DEFAULT_DATE, )
        op1 = DummyOperator(task_id='dummy1', dag=dag)
        op2 = DummyOperator(task_id='dummy2', dag=dag, pool_slots=2)
        ti1 = TI(task=op1, execution_date=DEFAULT_DATE)
        ti2 = TI(task=op2, execution_date=DEFAULT_DATE)
        ti1.state = State.RUNNING
        ti2.state = State.QUEUED

        session = settings.Session  # scoped_session, as noted in Example #1
        session.add(ti1)
        session.add(ti2)
        session.commit()
        session.close()

        self.assertEqual(2, Pool.get_default_pool().open_slots()) 
Example #26
Source File: test_backfill_job.py    From airflow with Apache License 2.0
def test_backfill_run_backwards(self):
        dag = self.dagbag.get_dag("test_start_date_scheduling")
        dag.clear()

        executor = MockExecutor(parallelism=16)

        job = BackfillJob(
            executor=executor,
            dag=dag,
            start_date=DEFAULT_DATE,
            end_date=DEFAULT_DATE + datetime.timedelta(days=1),
            run_backwards=True
        )
        job.run()

        session = settings.Session()
        tis = session.query(TI).filter(
            # Both predicates are passed as separate filter() arguments;
            # joining SQLAlchemy expressions with Python's `and` would
            # silently keep only one of them.
            TI.dag_id == 'test_start_date_scheduling',
            TI.task_id == 'dummy'
        ).order_by(TI.execution_date).all()

        queued_times = [ti.queued_dttm for ti in tis]
        self.assertTrue(queued_times == sorted(queued_times, reverse=True))
        self.assertTrue(all([ti.state == State.SUCCESS for ti in tis]))

        dag.clear()
        session.close() 
Example #27
Source File: test_dagrun.py    From airflow with Apache License 2.0
def test_get_task_instance_on_empty_dagrun(self):
        """
        Make sure that a proper value is returned when a dagrun has no task instances
        """
        dag = DAG(
            dag_id='test_get_task_instance_on_empty_dagrun',
            start_date=timezone.datetime(2017, 1, 1)
        )
        ShortCircuitOperator(
            task_id='test_short_circuit_false',
            dag=dag,
            python_callable=lambda: False)

        session = settings.Session()

        now = timezone.utcnow()

        # Don't use create_dagrun since it will create the task instances too which we
        # don't want
        dag_run = models.DagRun(
            dag_id=dag.dag_id,
            run_type=DagRunType.MANUAL.value,
            execution_date=now,
            start_date=now,
            state=State.RUNNING,
            external_trigger=False,
        )
        session.add(dag_run)
        session.commit()

        ti = dag_run.get_task_instance('test_short_circuit_false')
        self.assertEqual(None, ti) 
Example #28
Source File: test_dagrun.py    From airflow with Apache License 2.0
def test_dagrun_no_deadlock_with_depends_on_past(self):
        session = settings.Session()
        dag = DAG('test_dagrun_no_deadlock',
                  start_date=DEFAULT_DATE)
        with dag:
            DummyOperator(task_id='dop', depends_on_past=True)
            DummyOperator(task_id='tc', task_concurrency=1)

        dag.clear()
        dr = dag.create_dagrun(run_id='test_dagrun_no_deadlock_1',
                               state=State.RUNNING,
                               execution_date=DEFAULT_DATE,
                               start_date=DEFAULT_DATE)
        dr2 = dag.create_dagrun(run_id='test_dagrun_no_deadlock_2',
                                state=State.RUNNING,
                                execution_date=DEFAULT_DATE + datetime.timedelta(days=1),
                                start_date=DEFAULT_DATE + datetime.timedelta(days=1))
        ti1_op1 = dr.get_task_instance(task_id='dop')
        dr2.get_task_instance(task_id='dop')
        ti2_op1 = dr.get_task_instance(task_id='tc')
        dr.get_task_instance(task_id='tc')
        ti1_op1.set_state(state=State.RUNNING, session=session)
        dr.update_state()
        dr2.update_state()
        self.assertEqual(dr.state, State.RUNNING)
        self.assertEqual(dr2.state, State.RUNNING)

        ti2_op1.set_state(state=State.RUNNING, session=session)
        dr.update_state()
        dr2.update_state()
        self.assertEqual(dr.state, State.RUNNING)
        self.assertEqual(dr2.state, State.RUNNING) 
Example #29
Source File: test_backfill_job.py    From airflow with Apache License 2.0
def test_backfill_pooled_tasks(self):
        """
        Test that queued tasks are executed by BackfillJob
        """
        session = settings.Session()
        pool = Pool(pool='test_backfill_pooled_task_pool', slots=1)
        session.add(pool)
        session.commit()
        session.close()

        dag = self.dagbag.get_dag('test_backfill_pooled_task_dag')
        dag.clear()

        executor = MockExecutor(do_update=True)
        job = BackfillJob(
            dag=dag,
            start_date=DEFAULT_DATE,
            end_date=DEFAULT_DATE,
            executor=executor)

        # run with timeout because this creates an infinite loop if not
        # caught
        try:
            with timeout(seconds=5):
                job.run()
        except AirflowTaskTimeout:
            pass
        ti = TI(
            task=dag.get_task('test_backfill_pooled_task'),
            execution_date=DEFAULT_DATE)
        ti.refresh_from_db()
        self.assertEqual(ti.state, State.SUCCESS) 
Example #30
Source File: test_local_task_job.py    From airflow with Apache License 2.0
def test_localtaskjob_heartbeat(self, mock_pid):
        session = settings.Session()
        dag = DAG(
            'test_localtaskjob_heartbeat',
            start_date=DEFAULT_DATE,
            default_args={'owner': 'owner1'})

        with dag:
            op1 = DummyOperator(task_id='op1')

        dag.clear()
        dr = dag.create_dagrun(run_id="test",
                               state=State.SUCCESS,
                               execution_date=DEFAULT_DATE,
                               start_date=DEFAULT_DATE,
                               session=session)
        ti = dr.get_task_instance(task_id=op1.task_id, session=session)
        ti.state = State.RUNNING
        ti.hostname = "blablabla"
        session.commit()

        job1 = LocalTaskJob(task_instance=ti,
                            ignore_ti_state=True,
                            executor=SequentialExecutor())
        self.assertRaises(AirflowException, job1.heartbeat_callback)

        mock_pid.return_value = 1
        ti.state = State.RUNNING
        ti.hostname = get_hostname()
        ti.pid = 1
        session.merge(ti)
        session.commit()

        job1.heartbeat_callback(session=None)

        mock_pid.return_value = 2
        self.assertRaises(AirflowException, job1.heartbeat_callback)
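The assertions here rely on heartbeat_callback raising AirflowException whenever the task instance's recorded hostname or PID disagrees with the current process: the bogus hostname trips it first, the matching hostname/PID pair then passes, and the changed PID (mock_pid.return_value = 2) trips it again.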