Coverage for autodiff_team29/node.py: 92%
159 statements
« prev ^ index » next coverage.py v7.2.3, created at 2023-04-15 21:37 +0000
« prev ^ index » next coverage.py v7.2.3, created at 2023-04-15 21:37 +0000
1from __future__ import annotations
2from typing import Union
3import warnings
5import numpy as np
6from numpy.typing import NDArray
class Node:
    """
    A node of a forward-mode automatic differentiation computational graph.

    Every node stores a symbolic representation (``symbol``), its analytical
    value (``value``) and its derivative (``derivative``). The arithmetic
    operators combine nodes (ints/floats are converted to constant nodes with
    derivative 0) into new nodes whose symbols describe the producing expression.

    Unless overwrite mode is enabled (see ``set_overwrite_mode``), every node is
    stored in a class-level registry keyed by its symbol. Constructing or
    computing a node whose symbol is already registered returns the stored
    instance instead of creating a new one.
    """

    # other types that are capable of being converted to Node
    _COMPATIBLE_VALUE_TYPES = (int, float)
    _COMPATIBLE_DERIVATIVE_TYPES = (int, float, np.ndarray)

    # store nodes that have been computed previously (symbol -> Node);
    # the registry is only written to while overwrite mode is off
    _OVERWRITE_MODE = False
    _NODE_REGISTRY = {}

    # only to be used for our benchmarking example
    # not to be used for any other purpose
    _NODES_COMPUTED_FOR_BENCHMARKING = 0

    def __new__(
        cls,
        symbol: str,
        value: Union[float, int],
        derivative: Union[int, float, NDArray],
        **kwargs,
    ) -> Node:
        """
        Represents a node which is the foundation of a computational graph.

        Parameters
        ----------
        symbol : str
            Symbolic representation of a Node instance that acts as a unique
            identifier. It is converted with ``str()``.
        value : int, float
            Analytical value of the node.
        derivative : int, float, numpy.ndarray
            Derivative with respect to the value attribute.

        **kwargs
        --------
        seed_vector : array_like
            A seed vector for computing partial derivatives of multi-variable
            functions. The stored derivative is ``derivative * np.array(seed_vector)``.
            For F: R^m --> R^n the seed vector should have length m, with a 1 in the
            direction of interest and 0 elsewhere.
            Any other keyword argument is ignored.

        Returns
        -------
        Node
            A new node. If overwrite mode is off and a node with the same symbol
            is already registered, that existing node is returned instead, and
            ``value``/``derivative`` are neither validated nor used.

        Raises
        ------
        TypeError
            If a new node is being created and ``value`` or ``derivative`` has an
            unsupported type.

        Examples
        --------
        >>> x = Node('x', 10, 1)
        >>> x
        Node(x,10,1)
        >>> x + x
        Node((x+x),20,2)
        """
        # registry keys are always strings (see _insert_node_to_registry)
        symbol = str(symbol)

        # check if node already exist before recreating
        if cls._check_node_exists(symbol):
            return cls._get_existing_node(symbol)

        # these raise TypeError for unsupported value / derivative types
        cls._check_foreign_value_type_compatibility(value)
        cls._check_foreign_derivative_type_compatibility(derivative)

        instance = super().__new__(cls)
        instance._symbol = symbol
        instance._value = value

        # a seed vector means we are dealing with an n-dimensional function
        if "seed_vector" in kwargs:
            instance._derivative = derivative * np.array(kwargs["seed_vector"])
        else:
            instance._derivative = derivative

        if not cls._OVERWRITE_MODE:
            cls._insert_node_to_registry(instance)

        # for benchmarking purposes only
        cls._NODES_COMPUTED_FOR_BENCHMARKING += 1
        return instance

    @property
    def symbol(self) -> str:
        """
        Returns symbolic representation of the computational node
        """
        return self._symbol

    @property
    def value(self) -> float | int:
        """
        Returns analytical value of the computational node
        """
        return self._value

    @property
    def derivative(self) -> int | float | NDArray:
        """
        Returns derivative value of the computational node
        (an ndarray when a seed vector was supplied or propagated)
        """
        return self._derivative

    @staticmethod
    def _check_foreign_value_type_compatibility(other_type: object) -> None:
        """
        Checks to see if an object can be used as the value of a node.

        Parameters
        ----------
        other_type : Any
            Python object that will be attempted to be converted to a Node.

        Raises
        ------
        TypeError
            Raises TypeError if type is unsupported.

        Examples
        --------
        >>> Node._check_foreign_value_type_compatibility(100)
        >>> Node._check_foreign_value_type_compatibility("100")
        Traceback (most recent call last):
            ...
        TypeError: Unsupported type 'str' for value attribute in class Node
        """
        if not isinstance(other_type, Node._COMPATIBLE_VALUE_TYPES):
            raise TypeError(
                f"Unsupported type '{type(other_type).__name__}' for value attribute in class Node"
            )

    @staticmethod
    def _check_foreign_derivative_type_compatibility(other_type: object) -> None:
        """
        Checks to see if an object can be used as the derivative of a node.

        Parameters
        ----------
        other_type : Any
            Python object that will be used as a Node derivative.

        Raises
        ------
        TypeError
            Raises TypeError if type is unsupported.

        Examples
        --------
        >>> Node._check_foreign_derivative_type_compatibility(100)
        >>> Node._check_foreign_derivative_type_compatibility("100")
        Traceback (most recent call last):
            ...
        TypeError: Unsupported type 'str' for derivative attribute in class Node
        """
        if not isinstance(other_type, Node._COMPATIBLE_DERIVATIVE_TYPES):
            raise TypeError(
                f"Unsupported type '{type(other_type).__name__}' for derivative attribute in class Node"
            )

    @classmethod
    def _convert_numeric_type_to_node(cls, to_convert: Union[int, float, Node]) -> Node:
        """
        Attempts to convert a numeric value into an instance of class Node.

        Parameters
        ----------
        to_convert : int, float, Node
            Object that will be converted to type Node. A Node is returned as-is.
            A number becomes a constant node (derivative 0) whose symbol is ``str(to_convert)``.

        Returns
        -------
        Node:
            instance of class Node created from to_convert.

        Raises
        ------
        TypeError
            If to_convert is an unsupported data type.
        """
        if isinstance(to_convert, Node):
            return to_convert

        # Validate before constructing. __new__ consults the registry first, so
        # e.g. the string "x" would otherwise silently resolve to the node "x".
        cls._check_foreign_value_type_compatibility(to_convert)
        return cls(
            symbol=str(to_convert),
            value=to_convert,
            derivative=0,
        )

    @staticmethod
    def _check_node_exists(key: str) -> bool:
        """
        Checks if an instance of class Node has already been created.

        Parameters
        ----------
        key : str
            Symbolic representation of a Node instance that acts as a unique identifier.

        Returns
        -------
        bool :
            True if key is found in the registry and overwrite mode is off.
            False otherwise.
        """
        return not Node._OVERWRITE_MODE and key in Node._NODE_REGISTRY

    @staticmethod
    def _get_existing_node(key: str) -> Node:
        """
        Returns existing Node instance to avoid recomputing nodes.

        Parameters
        ----------
        key : str
            Symbolic representation of a Node instance that acts as a unique identifier.

        Returns
        -------
        Node :
            instance that matches the specified key.

        Raises
        ------
        KeyError
            If no node is registered under key.
        """
        return Node._NODE_REGISTRY[key]

    @staticmethod
    def _insert_node_to_registry(node: Node) -> None:
        """
        Adds Node instance to the registry, keyed by its symbol, so the
        computational graph can keep track of which nodes have already been computed.

        Parameters
        ----------
        node : Node
            Instance of class Node.
        """
        Node._NODE_REGISTRY[node._symbol] = node

    @classmethod
    def count_nodes_stored(cls) -> int:
        """
        Returns the number of nodes currently stored in the registry.
        """
        return len(Node._NODE_REGISTRY)

    @classmethod
    def set_overwrite_mode(cls, enabled: bool) -> None:
        """
        Allows existing nodes to be recomputed.
        Be warned, this can result in a significant performance decrease!

        Every call emits a RuntimeWarning describing the outcome. Enabling the
        mode also clears the node registry.

        Parameters
        ----------
        enabled : bool
            If true, existing nodes will be recomputed.
            Otherwise, existing computations will be retrieved from the node registry.
        """
        # normalise so truthy non-bool arguments behave like True
        enabled = bool(enabled)

        # if trying to set the mode to the current status, do nothing
        if cls._OVERWRITE_MODE == enabled:
            warnings.warn(
                f"Override mode is already set to {enabled}. Expect no changes",
                RuntimeWarning,
            )
            return

        if enabled:
            warnings.warn(
                "Override mode is enabled. Nodes with the same symbolic representation will be recomputed. "
                "Expect potential performance decrease",
                RuntimeWarning,
            )
            # clear registry when overwrite mode is enabled because we will not need it
            cls.clear_node_registry()
        else:
            warnings.warn(
                "Override mode is disabled. Nodes with the same symbolic representation will not be recomputed",
                RuntimeWarning,
            )

        cls._OVERWRITE_MODE = enabled

    @staticmethod
    def clear_node_registry() -> None:
        """
        Removes all key value pairs currently stored in the node registry.
        WARNING previous computations made by the graph will be permanently erased.
        """
        Node._NODE_REGISTRY.clear()

    @staticmethod
    def _power_tangent(base_value, base_derivative, exponent_value, exponent_derivative):
        """
        Derivative of ``base ** exponent`` from the primal values and tangents.

        Uses d(u**v) = v * u**(v - 1) * du + u**v * ln(u) * dv and evaluates each
        term only when its tangent is non-zero. Skipping the ln(u) term for a
        constant exponent keeps x**n well defined for x <= 0, where ln is not.
        Skipping the first term for a constant base avoids evaluating u**(v - 1),
        which fails for u == 0 and v < 1.
        """
        if np.any(base_derivative != 0):
            power_term = exponent_value * base_value ** (exponent_value - 1) * base_derivative
        else:
            # multiply instead of using a literal 0 so an ndarray tangent keeps its shape
            power_term = base_derivative * 0

        if np.any(exponent_derivative != 0):
            log_term = base_value**exponent_value * np.log(base_value) * exponent_derivative
        else:
            log_term = exponent_derivative * 0

        return power_term + log_term

    def __add__(self, other: Union[int, float, Node]) -> Node:
        """Return the node for ``self + other`` (sum rule)."""
        # convert first so unsupported operands raise TypeError even when a node
        # with the resulting symbol is already registered
        other = self._convert_numeric_type_to_node(other)
        # addition is commutative: sorting gives x+y and y+x the same symbol
        symbolic_representation = "({}+{})".format(*sorted([self._symbol, other._symbol]))

        if self._check_node_exists(symbolic_representation):
            return self._get_existing_node(symbolic_representation)

        primal_trace = self._value + other._value
        tangent_trace = self._derivative + other._derivative
        return Node(symbolic_representation, primal_trace, tangent_trace)

    def __radd__(self, other: Union[int, float]) -> Node:
        """Return the node for ``other + self``."""
        return self.__add__(other)

    def __sub__(self, other: Union[int, float, Node]) -> Node:
        """Return the node for ``self - other``."""
        other = self._convert_numeric_type_to_node(other)
        symbolic_representation = "({}-{})".format(self._symbol, other._symbol)

        if self._check_node_exists(symbolic_representation):
            return self._get_existing_node(symbolic_representation)

        primal_trace = self._value - other._value
        tangent_trace = self._derivative - other._derivative
        return Node(symbolic_representation, primal_trace, tangent_trace)

    def __rsub__(self, other: Union[int, float]) -> Node:
        """Return the node for ``other - self``."""
        other = self._convert_numeric_type_to_node(other)
        symbolic_representation = "({}-{})".format(other._symbol, self._symbol)

        if self._check_node_exists(symbolic_representation):
            return self._get_existing_node(symbolic_representation)

        primal_trace = other._value - self._value
        tangent_trace = other._derivative - self._derivative
        return Node(symbolic_representation, primal_trace, tangent_trace)

    def __mul__(self, other: Union[int, float, Node]) -> Node:
        """Return the node for ``self * other`` (product rule)."""
        other = self._convert_numeric_type_to_node(other)
        # multiplication is commutative: sorting gives x*y and y*x the same symbol
        symbolic_representation = "({}*{})".format(*sorted([self._symbol, other._symbol]))

        if self._check_node_exists(symbolic_representation):
            return self._get_existing_node(symbolic_representation)

        primal_trace = self._value * other._value
        tangent_trace = self._value * other._derivative + other._value * self._derivative
        return Node(symbolic_representation, primal_trace, tangent_trace)

    def __rmul__(self, other: Union[int, float]) -> Node:
        """Return the node for ``other * self``."""
        return self.__mul__(other)

    def __truediv__(self, other: Union[int, float, Node]) -> Node:
        """Return the node for ``self / other`` (quotient rule)."""
        other = self._convert_numeric_type_to_node(other)
        symbolic_representation = "({}/{})".format(self._symbol, other._symbol)

        if self._check_node_exists(symbolic_representation):
            return self._get_existing_node(symbolic_representation)

        primal_trace = self._value / other._value
        tangent_trace = (
            self._derivative * other._value - self._value * other._derivative
        ) / other._value**2
        return Node(symbolic_representation, primal_trace, tangent_trace)

    def __rtruediv__(self, other: Union[int, float]) -> Node:
        """Return the node for ``other / self`` (quotient rule)."""
        other = self._convert_numeric_type_to_node(other)
        symbolic_representation = "({}/{})".format(other._symbol, self._symbol)

        if self._check_node_exists(symbolic_representation):
            return self._get_existing_node(symbolic_representation)

        primal_trace = other._value / self._value
        tangent_trace = (
            self._value * other._derivative - other._value * self._derivative
        ) / self._value**2
        return Node(symbolic_representation, primal_trace, tangent_trace)

    def __neg__(self) -> Node:
        """Return the node for ``-self``."""
        symbolic_representation = "-{}".format(self._symbol)

        if self._check_node_exists(symbolic_representation):
            return self._get_existing_node(symbolic_representation)

        primal_trace = -1 * self._value
        tangent_trace = -1 * self._derivative
        return Node(symbolic_representation, primal_trace, tangent_trace)

    def __pow__(self, exponent: Union[int, float, Node]) -> Node:
        """Return the node for ``self ** exponent``."""
        exponent = self._convert_numeric_type_to_node(exponent)
        symbolic_representation = "({}**{})".format(self._symbol, exponent._symbol)

        if self._check_node_exists(symbolic_representation):
            return self._get_existing_node(symbolic_representation)

        primal_trace = self._value**exponent._value
        tangent_trace = self._power_tangent(
            self._value, self._derivative, exponent._value, exponent._derivative
        )
        return Node(symbolic_representation, primal_trace, tangent_trace)

    def __rpow__(self, base: Union[int, float]) -> Node:
        """Return the node for ``base ** self``."""
        base = self._convert_numeric_type_to_node(base)
        symbolic_representation = "({}**{})".format(base._symbol, self._symbol)

        if self._check_node_exists(symbolic_representation):
            return self._get_existing_node(symbolic_representation)

        primal_trace = base._value**self._value
        tangent_trace = self._power_tangent(
            base._value, base._derivative, self._value, self._derivative
        )
        return Node(symbolic_representation, primal_trace, tangent_trace)

    def __str__(self) -> str:
        return self._symbol

    def __repr__(self) -> str:
        return f"Node({self._symbol},{self._value},{self._derivative})"

    def __eq__(self, other: object) -> bool:
        """
        Two nodes are equal when their symbols, values and derivatives all match.

        Comparing with a non-Node returns ``NotImplemented``, so ``==`` falls back
        to Python's default handling (``False``) instead of raising.
        """
        if not isinstance(other, Node):
            return NotImplemented
        # np.array_equal handles both scalar and seed-vector (ndarray) derivatives
        return bool(
            self._symbol == other._symbol
            and self._value == other._value
            and np.array_equal(self._derivative, other._derivative)
        )