blob: 152862ff2b393c5d7dac3e7541e22ed7133004cd [file] [log] [blame]
Illyoung Choid1e4f5d2019-07-22 16:49:20 -07001#!/usr/bin/env python3
2
3# Copyright 2019-present Open Networking Foundation
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17from airflow.plugins_manager import AirflowPlugin
18from airflow.hooks.base_hook import BaseHook
19from airflow.operators.python_operator import PythonOperator
20from airflow.sensors.base_sensor_operator import BaseSensorOperator
21from airflow.utils.decorators import apply_defaults
22from cord_workflow_controller_client.workflow_run import WorkflowRun
23
24
25"""
26Airflow Hook
27"""
28
29
class CORDWorkflowControllerException(Exception):
    """
    Error raised when an interaction with the CORD Workflow Controller fails.

    It adds nothing to ``Exception``; it exists so callers can catch
    controller failures specifically.
    """
34
35
class CORDWorkflowControllerHook(BaseHook):
    """
    Hook for accessing the CORD Workflow Controller.

    Holds at most one ``WorkflowRun`` client for a single workflow run.
    ``get_conn()`` creates and connects the client lazily, and
    ``close_conn()`` releases it. The hook is also a context manager that
    closes the client on exit.

    Failures reported by the client are re-raised as
    ``CORDWorkflowControllerException``, with the original error chained as
    ``__cause__``.
    """

    def __init__(
            self,
            workflow_id,
            workflow_run_id,
            controller_conn_id='cord_controller_default'):
        """
        :param workflow_id: workflow identifier handed to ``WorkflowRun``.
        :param workflow_run_id: workflow run identifier handed to
            ``WorkflowRun``.
        :param controller_conn_id: Airflow connection id. Its ``host`` field
            is passed to ``WorkflowRun.connect()``.
        """
        super().__init__(source=None)
        self.workflow_id = workflow_id
        self.workflow_run_id = workflow_run_id
        self.controller_conn_id = controller_conn_id

        # Lazily created client; stays None until get_conn() connects successfully.
        self.workflow_run_client = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.workflow_run_client is not None:
            self.close_conn()

    def get_conn(self):
        """
        Return a connected Workflow Run client, connecting on first use.

        :raises CORDWorkflowControllerException: if creating or connecting
            the client fails.
        """
        if self.workflow_run_client is None:
            # Find connection info from the database or the environment.
            # ENV: AIRFLOW_CONN_CORD_CONTROLLER_DEFAULT
            connection_params = self.get_connection(self.controller_conn_id)
            # connection_params have three fields:
            #   host
            #   login - we don't use this yet
            #   password - we don't use this yet
            try:
                client = WorkflowRun(self.workflow_id, self.workflow_run_id)
                client.connect(connection_params.host)
            except Exception as ex:
                raise CORDWorkflowControllerException(ex) from ex
            # Cache only after a successful connect. Otherwise a failed attempt
            # would leave an unconnected client that later calls reuse.
            self.workflow_run_client = client

        return self.workflow_run_client

    def close_conn(self):
        """
        Close the Workflow Run client, if any.

        The cached client is dropped even when ``disconnect()`` fails, so a
        later ``get_conn()`` starts from a fresh client.

        :raises CORDWorkflowControllerException: if disconnecting fails.
        """
        client = self.workflow_run_client
        self.workflow_run_client = None
        if client:
            try:
                client.disconnect()
            except Exception as ex:
                raise CORDWorkflowControllerException(ex) from ex

    def update_status(self, task_id, status):
        """
        Update status of the workflow run.

        ``status`` should be one of ['begin', 'end'].

        :raises CORDWorkflowControllerException: on connection or client failure.
        """
        client = self.get_conn()
        try:
            return client.update_status(task_id, status)
        except Exception as ex:
            raise CORDWorkflowControllerException(ex) from ex

    def count_events(self):
        """
        Count queued events for the workflow run.

        :raises CORDWorkflowControllerException: on connection or client failure.
        """
        client = self.get_conn()
        try:
            return client.count_events()
        except Exception as ex:
            raise CORDWorkflowControllerException(ex) from ex

    def fetch_event(self, task_id, topic):
        """
        Fetch an event on ``topic`` for the workflow run.

        Returns whatever the client's ``fetch_event`` returns. The sensor
        treats a falsy result as "no event yet".

        :raises CORDWorkflowControllerException: on connection or client failure.
        """
        client = self.get_conn()
        try:
            return client.fetch_event(task_id, topic)
        except Exception as ex:
            raise CORDWorkflowControllerException(ex) from ex
122
123
124"""
125Airflow Operators
126"""
127
128
class CORDModelOperator(PythonOperator):
    """
    PythonOperator that hands CORD-specific keyword arguments to its callable.

    The Airflow context is always provided to ``python_callable``. Two extra
    keyword arguments are added on top of ``op_kwargs``:

    * ``model_accessor``: currently always ``None`` (not implemented yet).
    * ``message``: the XCom value of the task named by
      ``cord_event_sensor_task_id`` (e.g. a CORD event sensor), or ``None``
      when that task id is not set.
    """

    # SCARLET
    # http://bootflat.github.io/color-picker.html
    ui_color = '#cf3a24'

    @apply_defaults
    def __init__(
            self,
            python_callable,
            cord_event_sensor_task_id=None,
            op_args=None,
            op_kwargs=None,
            provide_context=True,
            templates_dict=None,
            templates_exts=None,
            *args,
            **kwargs
    ):
        # NOTE: the provide_context argument is accepted for signature
        # compatibility but ignored. execute_callable() reads 'ti' from
        # op_kwargs, which PythonOperator fills from the context only when
        # provide_context is on, so it is always forced to True.
        super().__init__(
            python_callable=python_callable,
            op_args=op_args,
            op_kwargs=op_kwargs,
            provide_context=True,
            templates_dict=templates_dict,
            templates_exts=templates_exts,
            *args,
            **kwargs)
        self.cord_event_sensor_task_id = cord_event_sensor_task_id

    def execute_callable(self):
        """Call ``python_callable`` with ``model_accessor`` and ``message`` injected."""
        sensor_task_id = self.cord_event_sensor_task_id
        message = (
            self.op_kwargs['ti'].xcom_pull(task_ids=sensor_task_id)
            if sensor_task_id else None
        )
        injected = {
            # TODO: model accessor is not implemented yet
            'model_accessor': None,
            'message': message,
        }
        call_kwargs = dict(self.op_kwargs)
        call_kwargs.update(injected)
        return self.python_callable(*self.op_args, **call_kwargs)
172
173
174"""
175Airflow Sensors
176"""
177
178
class CORDEventSensor(BaseSensorOperator):
    """
    Sensor that waits for an event on a CORD Workflow Controller topic.

    Each poke asks the controller for an event on ``topic``. Once one arrives,
    the sensor succeeds and ``execute()`` returns the event, so it is passed
    to downstream tasks via XCom.
    """

    # STEEL BLUE
    # http://bootflat.github.io/color-picker.html
    ui_color = '#4b77be'

    @apply_defaults
    def __init__(
            self,
            topic,
            key_field,
            controller_conn_id='cord_controller_default',
            *args,
            **kwargs):
        """
        :param topic: controller topic to wait on.
        :param key_field: stored on the sensor; not used by this class itself.
        :param controller_conn_id: Airflow connection id for the controller.
        """
        super().__init__(*args, **kwargs)

        self.topic = topic
        self.key_field = key_field
        self.controller_conn_id = controller_conn_id
        # Last event fetched by poke(); returned from execute().
        self.message = None
        # Controller hook; only set while execute() is running.
        self.hook = None

    def __create_hook(self, context):
        """
        Return a controller hook bound to this DAG and the current DAG run.
        """
        return CORDWorkflowControllerHook(self.dag_id, context['dag_run'].run_id, self.controller_conn_id)

    def execute(self, context):
        """
        Overridden to allow messages to be passed to next tasks via XCOM.

        Reports 'begin' and 'end' status around the poking loop and returns
        the fetched event. The controller connection is always closed, even
        when poking fails or times out.
        """
        if self.hook is None:
            self.hook = self.__create_hook(context)

        try:
            self.hook.update_status(self.task_id, 'begin')

            super().execute(context)

            self.hook.update_status(self.task_id, 'end')
        finally:
            # Previously a failure in super().execute() (e.g. sensor timeout)
            # skipped this and leaked the controller connection.
            try:
                self.hook.close_conn()
            finally:
                self.hook = None
        return self.message

    def poke(self, context):
        """
        Try once to fetch an event on ``self.topic``.

        Stores the event in ``self.message`` and returns True when one is
        available; otherwise returns False. Relies on ``self.hook`` having
        been created by ``execute()``.
        """
        # we need to use notification to immediately react at event
        # https://github.com/apache/airflow/blob/master/airflow/sensors/base_sensor_operator.py#L122
        self.log.info('Poking : trying to fetch a message with a topic %s', self.topic)
        event = self.hook.fetch_event(self.task_id, self.topic)
        if event:
            self.message = event
            return True
        return False
231
232
class CORDModelSensor(CORDEventSensor):
    """
    CORDEventSensor that waits for data model events.

    Listens on the topic ``datamodel.<model_name>``.
    """

    # SISKIN SPROUT YELLOW
    # http://bootflat.github.io/color-picker.html
    ui_color = '#7a942e'

    @apply_defaults
    def __init__(
            self,
            model_name,
            key_field,
            controller_conn_id='cord_controller_default',
            *args,
            **kwargs):
        """
        :param model_name: data model name; the topic is ``datamodel.<model_name>``.
        :param key_field: forwarded to CORDEventSensor.
        :param controller_conn_id: Airflow connection id for the controller.
        """
        topic = 'datamodel.%s' % model_name
        # key_field and controller_conn_id must be forwarded. Previously both
        # were dropped, so the parent's required key_field was missing and
        # construction always failed. They are passed as keywords because
        # Airflow's apply_defaults rejects positional operator arguments.
        super().__init__(
            topic=topic,
            key_field=key_field,
            controller_conn_id=controller_conn_id,
            *args,
            **kwargs)
        self.model_name = model_name
248
249
250"""
251Airflow Plugin Definition
252"""
253
254
255# Defining the plugin class
class CORD_Workflow_Airflow_Plugin(AirflowPlugin):
    """
    Registers the CORD workflow hook, operator and sensors with Airflow.
    """
    name = "CORD_Workflow_Airflow_Plugin"

    # Components this plugin contributes.
    hooks = [CORDWorkflowControllerHook]
    operators = [CORDModelOperator]
    sensors = [CORDEventSensor, CORDModelSensor]

    # Plugin extension points left empty.
    executors = []
    macros = []
    admin_views = []
    flask_blueprints = []
    menu_links = []
    appbuilder_views = []
    appbuilder_menu_items = []
267 appbuilder_menu_items = []