Coverage for autodiff_team29/node.py: 92%

159 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-04-15 21:37 +0000

1from __future__ import annotations 

2from typing import Union 

3import warnings 

4 

5import numpy as np 

6from numpy.typing import NDArray 

7 

8 

9class Node: 

# Python types accepted as a Node value (checked by
# _check_foreign_value_type_compatibility).
_COMPATIBLE_VALUE_TYPES = (int, float)
# Derivatives may additionally be numpy arrays, e.g. when a seed vector is
# passed to the constructor.
_COMPATIBLE_DERIVATIVE_TYPES = _COMPATIBLE_VALUE_TYPES + (np.ndarray,)

# Memoisation state: nodes are cached by symbol in _NODE_REGISTRY; the cache
# is bypassed while _OVERWRITE_MODE is enabled (see set_overwrite_mode).
_OVERWRITE_MODE = False
_NODE_REGISTRY = dict()

# Incremented once per newly created Node instance.
# Only meant for the benchmarking example - do not use for anything else.
_NODES_COMPUTED_FOR_BENCHMARKING = 0

21 

def __new__(
    cls,
    symbol: str,
    value: Union[float, int],
    derivative: Union[int, float, NDArray],
    **kwargs,
) -> Node:
    """
    Create (or reuse) a node of the computational graph.

    Nodes are memoised by symbol. Unless overwrite mode is enabled, asking
    for a symbol that is already in the registry returns the stored node,
    and the ``value``/``derivative`` arguments of this call are ignored.

    Parameters
    ----------
    symbol : str
        Symbolic representation acting as the node's unique identifier.
        Non-string symbols are converted with ``str`` before the registry
        lookup, so the lookup key matches the key the node is stored under.
    value : int, float
        Analytical (primal) value of the node.
    derivative : int, float, numpy.ndarray
        Derivative (tangent) value of the node.

    Other Parameters
    ----------------
    seed_vector : array_like, optional (keyword only, via ``**kwargs``)
        Seed vector for multi-variable functions. When given, the stored
        derivative is ``derivative * np.array(seed_vector)``. For
        F: R^m -> R^n the seed vector has length m, with 1 in the direction
        of interest and 0 elsewhere.

    Returns
    -------
    Node
        The newly created node, or the cached node with the same symbol.

    Raises
    ------
    TypeError
        If ``value`` or ``derivative`` has an unsupported type. This is only
        checked when a new node is actually created.

    Examples
    --------
    >>> Node.clear_node_registry()
    >>> x = Node('x', 10, 1)
    >>> x
    Node(x,10,1)
    >>> x + x
    Node((x+x),20,2)
    """
    # Normalise the key first: instances are registered under ``_symbol``
    # (a str), so looking up the raw argument would never hit the cache for
    # non-string symbols.
    symbol = str(symbol)

    # Reuse a previously computed node instead of recreating it.
    if cls._check_node_exists(symbol):
        return cls._get_existing_node(symbol)

    # Raise TypeError on unsupported value/derivative types.
    cls._check_foreign_value_type_compatibility(value)
    cls._check_foreign_derivative_type_compatibility(derivative)

    instance = super().__new__(cls)
    instance._symbol = symbol
    instance._value = value

    # A seed vector turns the scalar derivative into a directional
    # (vector-valued) derivative for n-dimensional functions.
    if "seed_vector" in kwargs:
        instance._derivative = derivative * np.array(kwargs["seed_vector"])
    else:
        instance._derivative = derivative

    # In overwrite mode the registry is bypassed entirely.
    if not cls._OVERWRITE_MODE:
        cls._insert_node_to_registry(instance)

    # For benchmarking purposes only.
    cls._NODES_COMPUTED_FOR_BENCHMARKING += 1
    return instance

91 

@property
def symbol(self) -> str:
    """
    Symbolic (string) representation of this node; it is also the key the
    node is stored under in the node registry.
    """
    node_symbol = self._symbol
    return node_symbol

99 

@property
def value(self) -> float | int:
    """
    Analytical (primal) value carried by this node.
    """
    primal = self._value
    return primal

107 

@property
def derivative(self) -> int | float | NDArray:
    """
    Derivative (tangent) value carried by this node.

    This is a scalar, or a numpy array when the node (or one of its
    operands) was created with a ``seed_vector``.
    """
    return self._derivative

115 

@staticmethod
def _check_foreign_value_type_compatibility(other_type: Union[int, float]) -> None:
    """
    Validate that an object can be stored as the value of a Node.

    Parameters
    ----------
    other_type : Any
        Object that is about to be used as a Node value.

    Raises
    ------
    TypeError
        If ``other_type`` is not an instance of one of
        ``Node._COMPATIBLE_VALUE_TYPES``.

    Examples
    --------
    >>> Node._check_foreign_value_type_compatibility(100)
    >>> Node._check_foreign_value_type_compatibility("100")
    Traceback (most recent call last):
        ...
    TypeError: Unsupported type 'str' for value attribute in class Node
    """
    if not isinstance(other_type, Node._COMPATIBLE_VALUE_TYPES):
        # Report the bare type name (e.g. 'str') rather than "<class 'str'>".
        raise TypeError(
            f"Unsupported type '{type(other_type).__name__}' for value attribute in class Node"
        )

142 

@staticmethod
def _check_foreign_derivative_type_compatibility(
    other_type: Union[int, float, NDArray]
) -> None:
    """
    Validate that an object can be stored as the derivative of a Node.

    Parameters
    ----------
    other_type : Any
        Object that is about to be used as a Node derivative.

    Raises
    ------
    TypeError
        If ``other_type`` is not an instance of one of
        ``Node._COMPATIBLE_DERIVATIVE_TYPES``.

    Examples
    --------
    >>> Node._check_foreign_derivative_type_compatibility(100)
    >>> Node._check_foreign_derivative_type_compatibility("100")
    Traceback (most recent call last):
        ...
    TypeError: Unsupported type 'str' for derivative attribute in class Node
    """
    if not isinstance(other_type, Node._COMPATIBLE_DERIVATIVE_TYPES):
        # This validates the *derivative*; the message must say so.
        raise TypeError(
            f"Unsupported type '{type(other_type).__name__}' for derivative attribute in class Node"
        )

171 

@classmethod
def _convert_numeric_type_to_node(cls, to_convert: Union[int, float, Node]) -> Node:
    """
    Promote a numeric constant to a Node; Nodes are returned unchanged.

    Parameters
    ----------
    to_convert : int, float, Node
        Operand to convert. Numbers become constant nodes whose symbol is
        ``str(to_convert)`` and whose derivative is 0.

    Returns
    -------
    Node
        ``to_convert`` itself if it already is a Node, otherwise the
        (possibly cached) constant node.

    Raises
    ------
    TypeError
        If ``to_convert`` is neither a Node nor a supported numeric type.
    """
    if isinstance(to_convert, Node):
        return to_convert

    # Validate before constructing. The constructor consults the symbol
    # registry *before* type-checking, so an unsupported object whose str()
    # matches an existing symbol (e.g. the string "x") would otherwise be
    # silently resolved to that unrelated node instead of raising.
    cls._check_foreign_value_type_compatibility(to_convert)

    return cls(
        symbol=str(to_convert),
        value=to_convert,
        derivative=0,
    )

200 

@staticmethod
def _check_node_exists(key: str) -> bool:
    """
    Report whether a node with the given symbol is already cached.

    Parameters
    ----------
    key : str
        Symbol of the node to look up.

    Returns
    -------
    bool
        Always False while overwrite mode is enabled (the registry is
        bypassed); otherwise whether ``key`` is present in the registry.
    """
    if Node._OVERWRITE_MODE:
        return False
    return key in Node._NODE_REGISTRY

218 

@staticmethod
def _get_existing_node(key: str) -> Node:
    """
    Fetch a previously computed node from the registry.

    Parameters
    ----------
    key : str
        Symbol of the cached node.

    Returns
    -------
    Node
        The cached instance registered under ``key``. A missing key raises
        KeyError, so callers check with ``_check_node_exists`` first.
    """
    registry = Node._NODE_REGISTRY
    cached_node = registry[key]
    return cached_node

237 

@staticmethod
def _insert_node_to_registry(node: Node) -> None:
    """
    Cache a node in the registry under its symbol, so later requests for the
    same symbol reuse it instead of recomputing it. An existing entry with
    the same symbol is replaced.

    Parameters
    ----------
    node : Node
        Instance to cache.
    """
    registry_key = node._symbol
    Node._NODE_REGISTRY[registry_key] = node

255 

@classmethod
def count_nodes_stored(cls) -> int:
    """
    Return how many nodes are currently held in the node registry.
    """
    stored_nodes = Node._NODE_REGISTRY
    return len(stored_nodes)

263 

264 

@classmethod
def set_overwrite_mode(cls, enabled: bool) -> None:
    """
    Enable or disable overwrite mode.

    With overwrite mode enabled, the node registry is bypassed and nodes
    whose symbol has been seen before are recomputed. This can be
    significantly slower. Every call emits a RuntimeWarning describing what
    happened.

    Parameters
    ----------
    enabled : bool
        True to recompute existing nodes; the registry is cleared and no
        longer populated. False to retrieve previously computed nodes from
        the registry again.
    """
    # Normalise truthy/falsy inputs so the stored flag is always a bool.
    enabled = bool(enabled)

    # Requesting the mode that is already active changes nothing.
    if cls._OVERWRITE_MODE == enabled:
        warnings.warn(
            f"Override mode is already set to {enabled}. Expect no changes",
            RuntimeWarning,
        )
        return

    if enabled:
        warnings.warn(
            "Override mode is enabled. Nodes with the same symbolic representation will be recomputed. "
            "Expect potential performance decrease",
            RuntimeWarning,
        )
        # The registry is neither read nor written in overwrite mode, so
        # release what it currently holds.
        cls.clear_node_registry()
    else:
        warnings.warn(
            "Override mode is disabled. Nodes with the same symbolic representation will not be recomputed",
            RuntimeWarning,
        )

    cls._OVERWRITE_MODE = enabled

305 

306 

@staticmethod
def clear_node_registry() -> None:
    """
    Remove every cached node from the node registry.

    WARNING: all nodes computed so far are forgotten; later operations will
    create them again instead of reusing them.
    """
    registry = Node._NODE_REGISTRY
    registry.clear()

315 

def __add__(self, other: Union[int, float, Node]) -> Node:
    """
    Return the node representing ``self + other``.

    Numeric operands are promoted to constant nodes (zero derivative).
    Because addition is commutative, the operand symbols are sorted so that
    ``x + y`` and ``y + x`` share one registry entry.

    ``other`` is converted *before* the result symbol is looked up, so a
    cached result can never be returned for an operand that the conversion
    step would reject.

    Raises
    ------
    TypeError
        If ``other`` cannot be converted to a Node.
    """
    other = self._convert_numeric_type_to_node(other)
    symbolic_representation = "({}+{})".format(
        *sorted([self._symbol, other._symbol])
    )

    if self._check_node_exists(symbolic_representation):
        return self._get_existing_node(symbolic_representation)

    # Sum rule: (u + v)' = u' + v'
    primal_trace = self._value + other._value
    tangent_trace = self._derivative + other._derivative

    return Node(
        symbolic_representation,
        primal_trace,
        tangent_trace,
    )

332 

def __radd__(self, other: Union[int, float]) -> Node:
    """
    Support ``number + node``. Addition is commutative, so this simply
    delegates to ``__add__``.
    """
    forward_add = self.__add__
    return forward_add(other)

335 

def __sub__(self, other: Union[int, float, Node]) -> Node:
    """
    Return the node representing ``self - other``.

    Numeric operands are promoted to constant nodes (zero derivative).
    ``other`` is converted *before* the result symbol is looked up, so a
    cached result can never be returned for an operand that the conversion
    step would reject.

    Raises
    ------
    TypeError
        If ``other`` cannot be converted to a Node.
    """
    other = self._convert_numeric_type_to_node(other)
    symbolic_representation = "({}-{})".format(self._symbol, other._symbol)

    if self._check_node_exists(symbolic_representation):
        return self._get_existing_node(symbolic_representation)

    # Difference rule: (u - v)' = u' - v'
    primal_trace = self._value - other._value
    tangent_trace = self._derivative - other._derivative

    return Node(symbolic_representation, primal_trace, tangent_trace)

348 

def __rsub__(self, other: Union[int, float]) -> Node:
    """
    Return the node representing ``other - self`` (reflected subtraction).

    ``other`` is promoted to a constant node (zero derivative) *before* the
    result symbol is looked up, so a cached result can never be returned for
    an operand that the conversion step would reject.

    Raises
    ------
    TypeError
        If ``other`` cannot be converted to a Node.
    """
    other = self._convert_numeric_type_to_node(other)
    symbolic_representation = "({}-{})".format(other._symbol, self._symbol)

    if self._check_node_exists(symbolic_representation):
        return self._get_existing_node(symbolic_representation)

    # Difference rule with the operands swapped.
    primal_trace = other._value - self._value
    tangent_trace = other._derivative - self._derivative

    return Node(symbolic_representation, primal_trace, tangent_trace)

361 

def __mul__(self, other: Union[int, float, Node]) -> Node:
    """
    Return the node representing ``self * other``.

    Numeric operands are promoted to constant nodes (zero derivative).
    Because multiplication is commutative, the operand symbols are sorted
    so that ``x * y`` and ``y * x`` share one registry entry. ``other`` is
    converted *before* the result symbol is looked up, so a cached result
    can never be returned for an operand that the conversion step would
    reject.

    Raises
    ------
    TypeError
        If ``other`` cannot be converted to a Node.
    """
    other = self._convert_numeric_type_to_node(other)
    symbolic_representation = "({}*{})".format(
        *sorted([self._symbol, other._symbol])
    )

    if self._check_node_exists(symbolic_representation):
        return self._get_existing_node(symbolic_representation)

    # Product rule: (u * v)' = u * v' + v * u'
    primal_trace = self._value * other._value
    tangent_trace = (
        self._value * other._derivative + other._value * self._derivative
    )

    return Node(symbolic_representation, primal_trace, tangent_trace)

376 

def __rmul__(self, other: Union[int, float]) -> Node:
    """
    Support ``number * node``. Multiplication is commutative, so this simply
    delegates to ``__mul__``.
    """
    forward_mul = self.__mul__
    return forward_mul(other)

379 

def __truediv__(self, other: Union[int, float, Node]) -> Node:
    """
    Return the node representing ``self / other``.

    Numeric operands are promoted to constant nodes (zero derivative).
    ``other`` is converted *before* the result symbol is looked up, so a
    cached result can never be returned for an operand that the conversion
    step would reject.

    Raises
    ------
    TypeError
        If ``other`` cannot be converted to a Node.
    ZeroDivisionError
        If ``other`` has value 0 (plain Python numbers).
    """
    other = self._convert_numeric_type_to_node(other)
    symbolic_representation = "({}/{})".format(self._symbol, other._symbol)

    if self._check_node_exists(symbolic_representation):
        return self._get_existing_node(symbolic_representation)

    # Quotient rule: (u / v)' = (u' * v - u * v') / v**2
    primal_trace = self._value / other._value
    tangent_trace = (
        self._derivative * other._value - self._value * other._derivative
    ) / other._value**2

    return Node(symbolic_representation, primal_trace, tangent_trace)

393 

def __rtruediv__(self, other: Union[int, float]) -> Node:
    """
    Return the node representing ``other / self`` (reflected division).

    ``other`` is promoted to a constant node (zero derivative) *before* the
    result symbol is looked up, so a cached result can never be returned for
    an operand that the conversion step would reject.

    Raises
    ------
    TypeError
        If ``other`` cannot be converted to a Node.
    ZeroDivisionError
        If ``self`` has value 0 (plain Python numbers).
    """
    other = self._convert_numeric_type_to_node(other)
    symbolic_representation = "({}/{})".format(other._symbol, self._symbol)

    if self._check_node_exists(symbolic_representation):
        return self._get_existing_node(symbolic_representation)

    # Quotient rule with self as the denominator:
    # (o / s)' = (s * o' - o * s') / s**2
    primal_trace = other._value / self._value
    tangent_trace = (
        self._value * other._derivative - other._value * self._derivative
    ) / self._value**2

    return Node(symbolic_representation, primal_trace, tangent_trace)

407 

def __neg__(self) -> Node:
    """
    Return the node representing ``-self``; both the value and the
    derivative are multiplied by -1. The result's symbol is the operand's
    symbol prefixed with ``-``.
    """
    negated_symbol = f"-{self._symbol}"

    if self._check_node_exists(negated_symbol):
        return self._get_existing_node(negated_symbol)

    return Node(negated_symbol, -1 * self._value, -1 * self._derivative)

418 

def __pow__(self, exponent: Union[int, float, Node]) -> Node:
    """
    Return the node representing ``self ** exponent``.

    The tangent trace uses the general power rule

        d(u**v) = v * u**(v - 1) * u' + u**v * ln(u) * v'

    and evaluates each term only when its derivative factor (u' or v') is
    non-zero. Otherwise that term is ``factor * 0``, which keeps the
    factor's shape when it is a numpy array. With a constant exponent this
    never evaluates ln(u) and never divides by u, so for example ``x**2``
    has a finite derivative at x = 0 and at x = -2.

    ``exponent`` is converted *before* the result symbol is looked up, so a
    cached result can never be returned for an operand that the conversion
    step would reject.

    Raises
    ------
    TypeError
        If ``exponent`` cannot be converted to a Node.
    ZeroDivisionError
        Possibly raised by Python's ``**`` when u is 0 and a negative power
        of it is required (the derivative is undefined there).
    """
    exponent = self._convert_numeric_type_to_node(exponent)
    symbolic_representation = "({}**{})".format(self._symbol, exponent._symbol)

    if self._check_node_exists(symbolic_representation):
        return self._get_existing_node(symbolic_representation)

    base_value, base_derivative = self._value, self._derivative
    exp_value, exp_derivative = exponent._value, exponent._derivative

    primal_trace = base_value**exp_value

    if np.any(base_derivative):
        power_term = exp_value * base_value ** (exp_value - 1) * base_derivative
    else:
        power_term = base_derivative * 0

    if np.any(exp_derivative):
        # ln(u) is only needed when the exponent itself varies.
        log_term = primal_trace * np.log(base_value) * exp_derivative
    else:
        log_term = exp_derivative * 0

    tangent_trace = power_term + log_term

    return Node(symbolic_representation, primal_trace, tangent_trace)

433 

def __rpow__(self, base: Union[int, float]) -> Node:
    """
    Return the node representing ``base ** self`` (reflected power).

    Uses the same split power rule as ``__pow__``:

        d(u**v) = v * u**(v - 1) * u' + u**v * ln(u) * v'

    Each term is evaluated only when its derivative factor is non-zero. A
    numeric base is promoted to a constant node (u' = 0), so the first term
    is skipped and no division by the base occurs.

    NOTE: for base == 0 the ln term computes 0 * ln(0), which numpy reports
    as nan (with a RuntimeWarning).

    ``base`` is converted *before* the result symbol is looked up, so a
    cached result can never be returned for an operand that the conversion
    step would reject.

    Raises
    ------
    TypeError
        If ``base`` cannot be converted to a Node.
    """
    base = self._convert_numeric_type_to_node(base)
    symbolic_representation = "({}**{})".format(base._symbol, self._symbol)

    if self._check_node_exists(symbolic_representation):
        return self._get_existing_node(symbolic_representation)

    base_value, base_derivative = base._value, base._derivative
    exp_value, exp_derivative = self._value, self._derivative

    primal_trace = base_value**exp_value

    if np.any(base_derivative):
        power_term = exp_value * base_value ** (exp_value - 1) * base_derivative
    else:
        power_term = base_derivative * 0

    if np.any(exp_derivative):
        log_term = primal_trace * np.log(base_value) * exp_derivative
    else:
        log_term = exp_derivative * 0

    tangent_trace = power_term + log_term

    return Node(symbolic_representation, primal_trace, tangent_trace)

448 

def __str__(self) -> str:
    """
    Return the node's symbol. This is the text the arithmetic operators
    embed when they build a composite symbol.
    """
    text = self._symbol
    return text

451 

def __repr__(self) -> str:
    """
    Return ``Node(<symbol>,<value>,<derivative>)``, with no spaces after the
    commas.
    """
    parts = (self._symbol, self._value, self._derivative)
    return "Node({},{},{})".format(*parts)

454 

def __eq__(self, other: object) -> bool:
    """
    Compare two nodes by symbol, value and derivative.

    Neither operand is modified. For non-Node operands this returns
    NotImplemented, so Python falls back to its default handling (for
    example ``node == 3`` evaluates to False instead of raising).
    """
    if not isinstance(other, Node):
        return NotImplemented

    # Derivatives may be numpy arrays (seed vectors). np.array_equal accepts
    # scalars and arrays alike and always returns a single bool, avoiding
    # the "truth value of an array is ambiguous" error.
    return (
        self._symbol == other._symbol
        and bool(self._value == other._value)
        and bool(np.array_equal(self._derivative, other._derivative))
    )