ICON Community Interface 0.4.0
gpu_test.py
Source listing; see the documentation page of this file for details.
import sys

import comin
import numpy as np

# Global descriptive data; its device fields are inspected below.
# NOTE(review): listing line 4 was lost; this definition is reconstructed from
# the file's cross-reference to comin.descrdata_get_global() — confirm.
glob = comin.descrdata_get_global()

if glob.has_device:
    comin.print_info(f"{glob.device_name=}")
    comin.print_info(f"{glob.device_vendor=}")
    comin.print_info(f"{glob.device_driver=}")

# Choose the array module ``xp``: CuPy on an NVIDIA device (if importable),
# NumPy otherwise. DEVICE_SYNC_FLAG is OR-ed into some var_get flags by the
# callbacks; it is COMIN_FLAG_DEVICE only when CuPy is actually in use
# (presumably requesting device-side access — confirm against the ComIn API).
if glob.has_device and "NVIDIA" in glob.device_vendor.upper():
    try:
        import cupy as xp
    except ImportError as e:
        comin.print_info("Cannot import cupy, falling back to numpy")
        comin.print_info(e)
        # Show the search path to help diagnose why cupy was not found.
        comin.print_info(sys.path)
        import numpy as xp

        DEVICE_SYNC_FLAG = 0
    else:
        # Announce CuPy only after the import has actually succeeded.
        comin.print_info("Using cupy!")
        DEVICE_SYNC_FLAG = comin.COMIN_FLAG_DEVICE
else:
    comin.print_info("No NVIDIA device found falling back to numpy")
    import numpy as xp

    DEVICE_SYNC_FLAG = 0

# NOTE(review): listing line 31 was lost; the cross-reference index also
# mentions comin.descrdata_get_domain, so it may have fetched per-domain
# descriptive data — confirm against the original source.
30
32
# Register requests for the two round-trip test variables on domain 1.
# NOTE(review): the ``True`` argument is passed through unchanged; presumably
# it is var_request_add's exclusivity flag — confirm against the ComIn API.
for _var_name in ("device_to_host", "host_to_device"):
    comin.var_request_add((_var_name, 1), True)
del _var_name  # keep the module namespace as before
35
36
@comin.register_callback(comin.EP_SECONDARY_CONSTRUCTOR)
def sec_ctor():
    """Obtain the variable handles used by the other callbacks.

    Sets five module-level globals:

    * ``ta`` -- ``("temp", 1)``, read at EP_ATM_WRITE_OUTPUT_BEFORE;
    * ``device_to_host1`` / ``host_to_device1`` -- written at
      EP_ATM_NUDGING_BEFORE;
    * ``device_to_host2`` / ``host_to_device2`` -- read at
      EP_ATM_NUDGING_AFTER.

    DEVICE_SYNC_FLAG (COMIN_FLAG_DEVICE when CuPy is in use, else 0) is added
    only to the device-side end of each round trip: the write of
    ``device_to_host`` and the read of ``host_to_device``.
    """
    global ta, device_to_host1, device_to_host2, host_to_device1, host_to_device2

    read_flag = comin.COMIN_FLAG_READ
    write_flag = comin.COMIN_FLAG_WRITE
    d2h = ("device_to_host", 1)
    h2d = ("host_to_device", 1)

    ta = comin.var_get([comin.EP_ATM_WRITE_OUTPUT_BEFORE], ("temp", 1), read_flag)

    # Before nudging: device_to_host is written on the device side,
    # host_to_device on the host side.
    device_to_host1 = comin.var_get(
        [comin.EP_ATM_NUDGING_BEFORE], d2h, write_flag | DEVICE_SYNC_FLAG
    )
    host_to_device1 = comin.var_get([comin.EP_ATM_NUDGING_BEFORE], h2d, write_flag)

    # After nudging: the opposite side reads each variable back.
    device_to_host2 = comin.var_get([comin.EP_ATM_NUDGING_AFTER], d2h, read_flag)
    host_to_device2 = comin.var_get(
        [comin.EP_ATM_NUDGING_AFTER], h2d, read_flag | DEVICE_SYNC_FLAG
    )
61
62
@comin.register_callback(comin.EP_ATM_WRITE_OUTPUT_BEFORE)
def foo():
    """Log array-interface details of ``ta`` and the mean of one level slice.

    Prints the CUDA array interface of the module-level ``ta`` handle (when it
    has one), converts it with ``np.asarray`` and prints the result's type,
    device and base. It then prints the mean of ``ta_arr[:, -1, :, 0, 0]`` on
    this process. The log message calls that slice the surface level; this
    assumes axis 1 is the vertical axis with the surface last — TODO confirm.
    """
    # Guarded like the ta_arr check below: on the no-device fallback path the
    # handle may not expose a CUDA array interface, and an unguarded access
    # would raise.
    if hasattr(ta, "__cuda_array_interface__"):
        comin.print_info(f"{ta.__cuda_array_interface__=}")
    ta_arr = np.asarray(ta)
    # NOTE(review): np.asarray returns a numpy.ndarray, which never defines
    # __cuda_array_interface__, so this branch cannot fire; xp.asarray may
    # have been intended — confirm before changing.
    if hasattr(ta_arr, "__cuda_array_interface__"):
        comin.print_info(f"{ta_arr.__cuda_array_interface__=}")
    comin.print_info(f"{type(ta_arr)=}")
    if hasattr(ta_arr, "device"):
        comin.print_info(f"{ta_arr.device=}")
    comin.print_info(f"{ta_arr.base}")
    comin.print_info("Computing mean surface temperature (on this process)")
    tas = ta_arr[:, -1, :, 0, 0]
    comin.print_info(f"{tas.mean()=}")
76
77
@comin.register_callback(comin.EP_ATM_NUDGING_BEFORE)
def set_to_42():
    """Fill the two round-trip test variables before nudging.

    ``device_to_host1`` is wrapped with ``xp`` (CuPy when a usable NVIDIA
    device was found, NumPy otherwise) and every element is set to 42.0.
    ``host_to_device1`` is wrapped with NumPy and every element is set to 43.0.
    The EP_ATM_NUDGING_AFTER callback checks both values.
    """
    # Slice assignment writes through the view returned by asarray.
    xp.asarray(device_to_host1)[:] = 42.0
    np.asarray(host_to_device1)[:] = 43.0
84
85
@comin.register_callback(comin.EP_ATM_NUDGING_AFTER)
def print_element():
    """Check after nudging that both test variables hold the values set before.

    ``device_to_host2`` is read through NumPy and must be close to 42.0.
    ``host_to_device2`` is read through ``xp`` and must be close to 43.0.
    A success message is logged after each check passes.

    Raises:
        AssertionError: if either variable is not ``allclose`` to its
            expected value.
    """
    # Explicit raises instead of ``assert``: asserts are stripped under
    # ``python -O``, which would log "check successful" without checking.
    device_to_host_np = np.asarray(device_to_host2)
    if not np.allclose(device_to_host_np, 42.0):
        raise AssertionError("check failed for device_to_host: expected 42.0")
    comin.print_info("check successful for device_to_host")
    host_to_device_xp = xp.asarray(host_to_device2)
    if not xp.allclose(host_to_device_xp, 43.0):
        raise AssertionError("check failed for host_to_device: expected 43.0")
    comin.print_info("check successful for host_to_device")
var_get(context, var_descriptor, flag)
get variable object; arguments: [entry point], (name string, domain id), access flag
Definition comin.py:107
descrdata_get_domain(jg)
returns descriptive data for a given domain; argument: jg
Definition comin.py:167
descrdata_get_global()
returns global descriptive data object
Definition comin.py:172
sec_ctor()
Definition gpu_test.py:38
set_to_42()
Definition gpu_test.py:79
print_element()
Definition gpu_test.py:87