# Copyright 2020-2024 Curtin University
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Author: Richard Hosking, James Diprose
from __future__ import annotations
from typing import List, Optional
import pendulum
from airflow import DAG
from airflow.decorators import dag, task, task_group
from airflow.models.baseoperator import chain
from airflow.operators.empty import EmptyOperator
import academic_observatory_workflows.doi_workflow.tasks as tasks
from academic_observatory_workflows.doi_workflow.queries import Aggregation, make_sql_queries, SQLQuery
from academic_observatory_workflows.doi_workflow.release import DOIRelease
from observatory_platform.airflow.airflow import on_failure_callback
from observatory_platform.airflow.release import make_snapshot_date
from observatory_platform.airflow.tasks import check_dependencies
from observatory_platform.airflow.workflow import CloudWorkspace
[docs]SENSOR_DAG_IDS = [
"crossref_fundref",
"crossref_metadata",
"openalex",
"ror",
"unpaywall",
"pubmed",
"orcid",
]
[docs]AGGREGATIONS = [
Aggregation(
"country",
"countries",
relate_to_journals=True,
relate_to_funders=True,
relate_to_publishers=True,
),
Aggregation(
"funder",
"funders",
relate_to_institutions=True,
relate_to_countries=True,
relate_to_groups=True,
relate_to_funders=True,
relate_to_publishers=True,
),
Aggregation(
"group",
"groupings",
relate_to_institutions=True,
relate_to_journals=True,
relate_to_funders=True,
relate_to_publishers=True,
),
Aggregation(
"institution",
"institutions",
relate_to_institutions=True,
relate_to_countries=True,
relate_to_journals=True,
relate_to_funders=True,
relate_to_publishers=True,
),
Aggregation(
"author",
"authors",
relate_to_institutions=True,
relate_to_countries=True,
relate_to_groups=True,
relate_to_journals=True,
relate_to_funders=True,
relate_to_publishers=True,
),
Aggregation(
"journal",
"journals",
relate_to_institutions=True,
relate_to_countries=True,
relate_to_groups=True,
relate_to_journals=True,
relate_to_funders=True,
),
Aggregation(
"publisher",
"publishers",
relate_to_institutions=True,
relate_to_countries=True,
relate_to_groups=False,
relate_to_funders=False,
),
# Aggregation(
# "region",
# "regions",
# relate_to_funders=True,
# relate_to_publishers=True,
# ),
Aggregation(
"subregion",
"subregions",
relate_to_funders=True,
relate_to_publishers=True,
),
]
[docs]class DagParams:
def __init__(
self,
dag_id: str,
cloud_workspace: CloudWorkspace,
bq_intermediate_dataset_id: str = "observatory_intermediate",
bq_dashboards_dataset_id: str = "coki_dashboards",
bq_observatory_dataset_id: str = "observatory",
bq_unpaywall_dataset_id: str = "unpaywall",
bq_ror_dataset_id: str = "ror",
api_bq_dataset_id: str = "dataset_api",
sql_queries: List[List[SQLQuery]] = None,
max_fetch_threads: int = 4,
start_date: Optional[pendulum.DateTime] = pendulum.datetime(2020, 8, 30),
schedule: Optional[str] = "@weekly",
sensor_dag_ids: List[str] = None,
max_active_runs: int = 1,
retries: int = 3,
**kwargs,
):
"""Create the DoiWorkflow.
:param dag_id: the DAG ID.
:param cloud_workspace: the cloud workspace settings.
:param bq_intermediate_dataset_id: the BigQuery intermediate dataset id.
:param bq_dashboards_dataset_id: the BigQuery dashboards dataset id.
:param bq_observatory_dataset_id: the BigQuery observatory dataset id.
:param bq_unpaywall_dataset_id: the BigQuery Unpaywall dataset id.
:param bq_ror_dataset_id: the BigQuery ROR dataset id.
:param api_bq_dataset_id: the Dataset ID to use when storing releases.
:param sql_queries: a list of batches of SQLQuery objects.
:param max_fetch_threads: maximum number of threads to use when fetching.
:param observatory_api_conn_id: the Observatory API connection key.
:param start_date: the start date.
:param schedule: the schedule interval.
:param sensor_dag_ids: a list of DAG IDs to wait for with sensors.
:param max_active_runs: the maximum number of DAG runs that can be run at once.
:param retries: the number of times to retry a task.
"""
[docs] self.cloud_workspace = cloud_workspace
[docs] self.bq_dashboards_dataset_id = bq_dashboards_dataset_id
[docs] self.bq_observatory_dataset_id = bq_observatory_dataset_id
[docs] self.bq_unpaywall_dataset_id = bq_unpaywall_dataset_id
[docs] self.bq_ror_dataset_id = bq_ror_dataset_id
[docs] self.api_bq_dataset_id = api_bq_dataset_id
[docs] self.sql_queries = sql_queries
if sql_queries is None:
self.sql_queries = make_sql_queries(
self.cloud_workspace.input_project_id,
self.cloud_workspace.output_project_id,
dataset_id_observatory=bq_observatory_dataset_id,
)
[docs] self.max_fetch_threads = max_fetch_threads
[docs] self.start_date = start_date
[docs] self.schedule = schedule
[docs] self.sensor_dag_ids = sensor_dag_ids if sensor_dag_ids is not None else SENSOR_DAG_IDS
[docs] self.max_active_runs = max_active_runs
input_table_task_ids = []
for batch in self.sql_queries:
for sql_query in batch:
task_id = sql_query.name
input_table_task_ids.append(task_id)
[docs]def create_dag(dag_params: DagParams) -> DAG:
@dag(
dag_id=dag_params.dag_id,
start_date=dag_params.start_date,
schedule=dag_params.schedule,
catchup=False,
max_active_runs=dag_params.max_active_runs,
tags=["academic-observatory"],
default_args={
"owner": "airflow",
"on_failure_callback": on_failure_callback,
"retries": dag_params.retries,
},
)
def doi():
@task_group(group_id="sensors")
def make_sensors():
"""Create the sensor tasks for the DAG."""
tasks.make_sensors(sensor_dag_ids=dag_params.sensor_dag_ids)
@task
def fetch_release(**context) -> dict:
"""Fetch a release instance."""
snapshot_date = make_snapshot_date(**context)
return DOIRelease(
dag_id=dag_params.dag_id,
run_id=context["run_id"],
snapshot_date=snapshot_date,
data_interval_start=context["data_interval_start"],
data_interval_end=context["data_interval_end"],
).to_dict()
@task
def create_datasets(release: dict, **context):
"""Create required BigQuery datasets."""
tasks.create_datasets(
output_project_id=dag_params.cloud_workspace.output_project_id,
bq_data_location=dag_params.cloud_workspace.data_location,
bq_intermediate_dataset_id=dag_params.bq_intermediate_dataset_id,
bq_dashboards_dataset_id=dag_params.bq_dashboards_dataset_id,
bq_observatory_dataset_id=dag_params.bq_observatory_dataset_id,
)
@task
def create_repo_institution_to_ror_table(release: dict, **context):
"""Create the repository_institution_to_ror_table."""
release = DOIRelease.from_dict(release)
tasks.create_repo_institution_to_ror_table(
release=release,
input_project_id=dag_params.cloud_workspace.input_project_id,
output_project_id=dag_params.cloud_workspace.output_project_id,
bq_unpaywall_dataset_id=dag_params.bq_unpaywall_dataset_id,
bq_intermediate_dataset_id=dag_params.bq_intermediate_dataset_id,
max_fetch_threads=dag_params.max_fetch_threads,
)
@task
def create_ror_hierarchy_table(release: dict, **context):
"""Create the ROR hierarchy table."""
release = DOIRelease.from_dict(release)
tasks.create_ror_hierarchy_table(
release=release,
input_project_id=dag_params.cloud_workspace.input_project_id,
output_project_id=dag_params.cloud_workspace.output_project_id,
bq_ror_dataset_id=dag_params.bq_ror_dataset_id,
bq_intermediate_dataset_id=dag_params.bq_intermediate_dataset_id,
)
@task_group(group_id="intermediate_tables")
def create_intermediate_tables(release: dict, **context):
"""Create intermediate table tasks."""
ts = []
for i, batch in enumerate(dag_params.sql_queries):
parallel = []
for sql_query in batch:
task_id = sql_query.name
t = create_intermediate_table.override(task_id=task_id)(
release,
sql_query=sql_query,
)
parallel.append(t)
merge = EmptyOperator(
task_id=f"merge_{i}",
)
ts.append(parallel)
ts.append(merge)
chain(*ts)
@task
def create_intermediate_table(release: dict, sql_query: SQLQuery, **context):
"""Create an intermediate table."""
release = DOIRelease.from_dict(release)
ti = context["ti"]
tasks.create_intermediate_table(
release=release,
sql_query=sql_query,
output_project_id=dag_params.cloud_workspace.output_project_id,
ti=ti,
)
@task_group(group_id="aggregate_tables")
def create_aggregate_tables(release: dict, **context):
"""Create aggregate table tasks."""
ts = []
for agg in AGGREGATIONS:
task_id = agg.table_name
t = create_aggregate_table.override(task_id=task_id)(
release,
aggregation=agg,
)
ts.append(t)
chain(ts)
@task
def create_aggregate_table(release: dict, aggregation: Aggregation, **context):
"""Runs the aggregate table query."""
release = DOIRelease.from_dict(release)
tasks.create_aggregate_table(
release=release,
aggregation=aggregation,
output_project_id=dag_params.cloud_workspace.output_project_id,
bq_observatory_dataset_id=dag_params.bq_observatory_dataset_id,
)
@task
def update_table_descriptions(release: dict, **context):
"""Update descriptions for tables."""
release = DOIRelease.from_dict(release)
ti = context["ti"]
tasks.update_table_descriptions(
release=release,
aggregations=AGGREGATIONS,
input_table_task_ids=dag_params.input_table_task_ids,
output_project_id=dag_params.cloud_workspace.output_project_id,
bq_observatory_dataset_id=dag_params.bq_observatory_dataset_id,
ti=ti,
)
@task
def add_dataset_release(release: dict, **context):
"""Adds release information to API."""
release = DOIRelease.from_dict(release)
tasks.add_dataset_release(
release=release,
api_bq_project_id=dag_params.cloud_workspace.project_id,
api_bq_dataset_id=dag_params.api_bq_dataset_id,
)
sensor_task_group = make_sensors()
task_check_dependencies = check_dependencies()
xcom_release = fetch_release()
create_datasets_task = create_datasets(xcom_release)
create_repo_institution_to_ror_table_task = create_repo_institution_to_ror_table(xcom_release)
create_ror_hierarchy_table_task = create_ror_hierarchy_table(xcom_release)
create_intermediate_tables_task_group = create_intermediate_tables(xcom_release)
create_aggregate_tables_task_group = create_aggregate_tables(xcom_release)
update_table_descriptions_task = update_table_descriptions(xcom_release)
add_dataset_release_task = add_dataset_release(xcom_release)
(
sensor_task_group
>> task_check_dependencies
>> xcom_release
>> create_datasets_task
>> create_repo_institution_to_ror_table_task
>> create_ror_hierarchy_table_task
>> create_intermediate_tables_task_group
>> create_aggregate_tables_task_group
>> update_table_descriptions_task
>> add_dataset_release_task
)
return doi()