diff --git a/test/test_generator.py b/test/test_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cc119b8468911f9a4bfeea6b139771a20cffeea
--- /dev/null
+++ b/test/test_generator.py
@@ -0,0 +1,35 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+
+from torch_npu.testing.testcase import TestCase, run_tests
+
+
+class GeneratorTest(TestCase):
+
+    def test_state(self):
+        gen = torch_npu._C.Generator()
+        gen.set_state(torch.get_rng_state())
+        self.assertEqual(gen.get_state(), torch.get_rng_state())
+
+    def test_seed(self):
+        gen = torch_npu._C.Generator(torch_npu.npu.current_device())
+        gen.manual_seed(1234)
+        self.assertEqual(gen.initial_seed(), 1234)
+
+
+if __name__ == '__main__':
+    run_tests()
\ No newline at end of file
diff --git a/torch_npu/__init__.py b/torch_npu/__init__.py
index c49a9472a447a8125ed0f05fb759a3a468e31cec..f062a69efa9ea6fe98ffe3cd899165901463563b 100644
--- a/torch_npu/__init__.py
+++ b/torch_npu/__init__.py
@@ -50,6 +50,7 @@ all_monkey_patches = [
     ["nn.parallel.distributed._get_default_group", torch_npu.distributed.distributed_c10d._get_default_group],
     ["nn.functional", npu_functional],
     ["nn", npu_modules],
+    ["_C.Generator", torch_npu._C.Generator]
 ]
 
 all_monkey_patches += serialization_patches
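
The new entry hands torch's `_C.Generator` attribute over to the NPU-aware type registered by this patch. As a rough sketch of what applying one ["dotted.path", replacement] pair amounts to (the `apply_patch` helper below is hypothetical, not torch_npu's actual patch applier):

    import functools
    import torch
    import torch_npu

    def apply_patch(dotted_path, replacement, root=torch):
        # Walk "a.b" down from `root`, then rebind the leaf attribute.
        *parents, leaf = dotted_path.split(".")
        target = functools.reduce(getattr, parents, root)
        setattr(target, leaf, replacement)

    apply_patch("_C.Generator", torch_npu._C.Generator)
    assert torch._C.Generator is torch_npu._C.Generator
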
AT_ERROR("Device type ", c10::DeviceTypeName(device.type()), + " is not supported for torch.Generator() api."); + } + return (PyObject*)self.release(); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPGenerator_getState(PyObject *_self, PyObject *noargs) +{ + using namespace torch::autograd; + HANDLE_TH_ERRORS + auto& gen = ((THPGenerator*)_self)->cdata; + + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen.mutex()); + auto state_tensor = gen.get_state(); + + return THPVariable_Wrap(std::move(state_tensor)); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPGenerator_setState(PyObject *_self, PyObject *_new_state) +{ + using namespace torch::autograd; + + HANDLE_TH_ERRORS + if (!THPVariable_Check(_new_state)) { + throw torch::TypeError("expected a torch.ByteTensor, but got %s", Py_TYPE(_new_state)->tp_name); + } + auto self = (THPGenerator*)_self; + auto& gen = self->cdata; + auto& new_state_tensor = ((THPVariable*)_new_state)->cdata; + + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen.mutex()); + gen.set_state(new_state_tensor); + + Py_INCREF(self); + return (PyObject*)self; + END_HANDLE_TH_ERRORS +} + +static PyObject * THPGenerator_manualSeed(PyObject *_self, PyObject *seed) +{ + HANDLE_TH_ERRORS + auto self = (THPGenerator*)_self; + auto generator = self->cdata; + THPUtils_assert(THPUtils_checkLong(seed), "manual_seed expected a long, " + "but got %s", THPUtils_typename(seed)); + // See Note [Acquire lock when using random generators] + std::lock_guard lock(generator.mutex()); + uint64_t seed_unpacked; + try { + // First try to interpret as unsigned long + seed_unpacked = THPUtils_unpackUInt64(seed); + } catch(...) { + if (PyErr_ExceptionMatches(PyExc_OverflowError)) { + // If an overflow happened, then the seed could be negative, + // so try to interpret it as signed long + PyErr_Clear(); + int64_t seed_unpacked_signed = THPUtils_unpackLong(seed); + seed_unpacked = *(reinterpret_cast(&seed_unpacked_signed)); + } else { + // If any other type of exception happened, rethrow it + throw; + } + } + generator.set_current_seed(seed_unpacked); + Py_INCREF(self); + return (PyObject*)self; + END_HANDLE_TH_ERRORS +} + +static PyObject * THPGenerator_seed(PyObject *_self, PyObject *noargs) +{ + HANDLE_TH_ERRORS + // See Note [Acquire lock when using random generators] + auto self = (THPGenerator*)_self; + std::lock_guard lock(self->cdata.mutex()); + uint64_t seed_val = self->cdata.seed(); + return THPUtils_packUInt64(seed_val); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPGenerator_initialSeed(PyObject *_self, PyObject *noargs) +{ + HANDLE_TH_ERRORS + auto self = (THPGenerator*)_self; + return THPUtils_packUInt64(self->cdata.current_seed()); + END_HANDLE_TH_ERRORS +} + +static PyObject * THPGenerator_get_device(THPGenerator *self, void *unused) { + HANDLE_TH_ERRORS + return THPDevice_New(self->cdata.device()); + END_HANDLE_TH_ERRORS +} + +static struct PyGetSetDef THPGenerator_properties[] = { + {"device", (getter)THPGenerator_get_device, nullptr, nullptr, nullptr}, + {nullptr} +}; + +static PyMethodDef THPGenerator_methods[] = { + {"get_state", THPGenerator_getState, METH_NOARGS, nullptr}, + {"set_state", THPGenerator_setState, METH_O, nullptr}, + {"manual_seed", THPGenerator_manualSeed, METH_O, nullptr}, + {"seed", THPGenerator_seed, METH_NOARGS, nullptr}, + {"initial_seed", THPGenerator_initialSeed, METH_NOARGS, nullptr}, + {nullptr} +}; + +static struct PyMemberDef THPGenerator_members[] = { + {(char*)"_cdata", 
diff --git a/torch_npu/csrc/Generator.h b/torch_npu/csrc/Generator.h
new file mode 100644
index 0000000000000000000000000000000000000000..ee8ca3dbf59c75e858fe9416bf48816fc727f75b
--- /dev/null
+++ b/torch_npu/csrc/Generator.h
@@ -0,0 +1,7 @@
+#pragma once
+
+#include <torch/csrc/python_headers.h>
+
+
+bool THPGenerator_init(PyObject *module);
+
diff --git a/torch_npu/csrc/InitNpuBindings.cpp b/torch_npu/csrc/InitNpuBindings.cpp
index 721d70e1467eeab8204b94f38e9a6a169f4f9fb6..35f38e360abc9ee22909c8939219f08bfeef3934 100644
--- a/torch_npu/csrc/InitNpuBindings.cpp
+++ b/torch_npu/csrc/InitNpuBindings.cpp
@@ -16,8 +16,9 @@
 #include
 #include
 #include
-#include "torch_npu/csrc/npu/Event.h"
+#include
 
+#include "torch_npu/csrc/npu/Event.h"
 #include "torch_npu/csrc/core/npu/NPUCachingAllocator.h"
 #include "torch_npu/csrc/framework/graph/execute/GraphExecutor.h"
 #include
@@ -26,9 +27,11 @@
 #include "torch_npu/csrc/core/npu/THNPUCachingHostAllocator.h"
 #include "torch_npu/csrc/distributed/Init.h"
 #include "torch_npu/csrc/profiler/init.h"
+#include "torch_npu/csrc/Generator.h"
 
 PyObject* module;
 
+
 void AddPyMethodDefs(std::vector<PyMethodDef>& vector, PyMethodDef* methods)
 {
   if (!vector.empty()) {
@@ -75,6 +78,7 @@ static PyMethodDef TorchNpuMethods[] = {
 
 void THNPStream_init(PyObject *module);
 void THNPEvent_init(PyObject *module);
+bool THPGenerator_init(PyObject *module);
 PyMethodDef* THNPModule_get_methods();
 
 namespace torch_npu { namespace autograd {
@@ -109,6 +113,7 @@ PyObject* initModule(){
 
   // C, so these lines have to execute first)..
   THNPStream_init(module);
  THNPEvent_init(module);
+  THPGenerator_init(module);
 
  torch_npu::autograd::initTorchFunctions(module);
diff --git a/torch_npu/csrc/aten/NPUGeneratorImpl.cpp b/torch_npu/csrc/aten/NPUGeneratorImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..1820d3178fbf5c0b5cfaa6711c6eef9763c6cdad
--- /dev/null
+++ b/torch_npu/csrc/aten/NPUGeneratorImpl.cpp
@@ -0,0 +1,335 @@
+#include <mutex>
+#include <ATen/Utils.h>
+#include <ATen/Functions.h>
+
+#include "torch_npu/csrc/aten/NPUGeneratorImpl.h"
+
+namespace at_npu {
+namespace detail {
+
+namespace {
+
+// Ensures npuGetDeviceCount is called only once.
+static std::once_flag num_npu_init_flag;
+
+// Total number of NPUs in the system.
+static int64_t num_npus;
+
+// Ensures default_gens_npu is initialized once.
+static std::deque<std::once_flag> npu_gens_init_flag;
+
+// Default, global NPU generators, one per NPU.
+static std::vector<at::Generator> default_gens_npu;
+
+/*
+* Populates the global variables related to NPU generators.
+* Warning: this function must only be called once!
+*/
+static void initNPUGenVector() {
+  num_npus = c10::npu::device_count();
+  npu_gens_init_flag.resize(num_npus);
+  default_gens_npu.resize(num_npus);
+}
+
+} // anonymous namespace
+
+/**
+ * PyTorch maintains a collection of default generators that get
+ * initialized once. The purpose of these default generators is to
+ * maintain a global running state of the pseudo random number generation,
+ * when a user does not explicitly mention any generator.
+ * getDefaultNPUGenerator gets the default generator for a particular
+ * NPU device.
+ */
+const at::Generator& getDefaultNPUGenerator(c10::DeviceIndex device_index) {
+  std::call_once(num_npu_init_flag, initNPUGenVector);
+  c10::DeviceIndex idx = device_index;
+  if (idx == -1) {
+    idx = c10::npu::current_device();
+  } else {
+    TORCH_CHECK(idx >= 0 && idx < num_npus);
+  }
+  std::call_once(npu_gens_init_flag[idx], [&] {
+    default_gens_npu[idx] = at::make_generator<NPUGeneratorImpl>(idx);
+    default_gens_npu[idx].seed();
+  });
+  return default_gens_npu[idx];
+}
+
+/**
+ * Utility to create an NPUGeneratorImpl, wrapped in an at::Generator.
+ */
+at::Generator createNPUGenerator(c10::DeviceIndex device_index) {
+  std::call_once(num_npu_init_flag, initNPUGenVector);
+  c10::DeviceIndex idx = device_index;
+  if (idx == -1) {
+    idx = c10::npu::current_device();
+  }
+  TORCH_CHECK(idx >= 0 && idx < num_npus, "The device_index is invalid.");
+  auto gen = at::make_generator<NPUGeneratorImpl>(idx);
+  auto npu_gen = at::check_generator<NPUGeneratorImpl>(gen);
+  npu_gen->set_current_seed(c10::default_rng_seed_val);
+  npu_gen->set_philox_offset_per_thread(0);
+  return gen;
+}
+
+} // namespace detail
+
+/**
+ * Note [Why enforce RNG offset % 4 == 0?]
+ * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ * Curand philox does allow offsets that aren't a multiple of 4.
+ * But jit kernels don't use curand, they use a custom "Philox" class (see
+ * torch/csrc/jit/tensorexpr/npu_random.h or
+ * torch/csrc/jit/codegen/npu/runtime/random_numbers.cu).
+ * The "Philox" constructor computes offset/4 (a uint64_t division) to locate its
+ * internal start in its virtual bitstream viewed as 128-bit chunks, then, when called
+ * in a thread, returns one 32-bit chunk at a time from that start in the bitstream.
+ * In other words, if the incoming offset is not a multiple of 4, each thread
+ * might repeat some previously-generated 32-bit values in the bitstream. See
+ * https://github.com/pytorch/pytorch/pull/50169.
+ */
+
+/**
+ * NPUGeneratorImpl class implementation
+ */
+NPUGeneratorImpl::NPUGeneratorImpl(c10::DeviceIndex device_index)
+  : c10::GeneratorImpl{c10::Device(c10::DeviceType::NPU, device_index),
+                       c10::DispatchKeySet(c10::DispatchKey::NPU)} {
+  // at::npu::assertNotCapturing("Cannot construct a new NPUGeneratorImpl");
+}
+
+/**
+ * Sets the seed to be used by curandStatePhilox4_32_10.
+ * Resets the philox_offset_per_thread_ to 0.
+ *
+ * See Note [Acquire lock when using random generators]
+ */
+void NPUGeneratorImpl::set_current_seed(uint64_t seed) {
+  seed_ = seed;
+  philox_offset_per_thread_ = 0;
+}
+
+#define CAPTURE_DEFAULT_GENS_MSG \
+"In regions captured by NPU graphs, you may only use the default NPU RNG " \
+"generator on the device that's current when capture begins. " \
+"If you need a non-default (user-supplied) generator, or a generator on another " \
+"device, please file an issue."
+
+/**
+ * Gets the current seed of NPUGeneratorImpl.
+ */
+uint64_t NPUGeneratorImpl::current_seed() const {
+  // Debatable if current_seed() should be allowed in captured regions.
+  // Conservatively disallow it for now.
+  return seed_;
+}
+
+/**
+ * Gets a nondeterministic random number from /dev/urandom or time,
+ * seeds the NPUGeneratorImpl with it and then returns that number.
+ *
+ * You can move this function to Generator.cpp if the algorithm
+ * in getNonDeterministicRandom is unified for both CPU and NPU.
+ */
+uint64_t NPUGeneratorImpl::seed() {
+  auto random = c10::detail::getNonDeterministicRandom(true);
+  this->set_current_seed(random);
+  return random;
+}
+
+/**
+ * Gets the current internal state of NPUGeneratorImpl. The internal
+ * state is returned as a CPU byte tensor.
+ */
+c10::intrusive_ptr<c10::TensorImpl> NPUGeneratorImpl::get_state() const {
+  // The RNG state comprises the seed, and an offset used for Philox.
+  // The following line is just here for BC reason. sizeof curandStateMtgp32 is 4120.
+  // It used to be static const size_t states_size = MAX_NUM_BLOCKS * sizeof(curandStateMtgp32);
+  // MAX_NUM_BLOCKS was 200 and sizeof(curandStateMtgp32) is 4120. Hardcoding these numbers here
+  // because this is just host side code and we don't want to worry about linking with the NPU runtime.
+  static const size_t states_size = 200 * sizeof(4120);
+  static const size_t seed_size = sizeof(uint64_t);
+  static const size_t offset_size = sizeof(int64_t);
+  static const size_t total_size = states_size + seed_size + offset_size;
+
+  auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, at::ScalarType::Byte,
+                                            c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt);
+  auto rng_state = state_tensor.data_ptr<uint8_t>();
+  // since curandStateMTGP is not used anymore, fill gen_states of THCGenerator with deterministic garbage value of -1
+  // gen_states in THCGenerator struct was an array of curandStateMtgp32s.
+  memset(rng_state, -1, states_size);
+  auto current_seed = this->current_seed();
+  auto offset = static_cast<int64_t>(this->philox_offset_per_thread()); // Note that old THCGeneratorState had offset as std::atomic
+  memcpy(rng_state + states_size, &current_seed, seed_size);
+  memcpy(rng_state + states_size + seed_size, &offset, offset_size);
+
+  return state_tensor.getIntrusivePtr();
+}
+
+/**
+ * Sets the internal state of NPUGeneratorImpl. The new internal state
+ * must be a strided CPU byte tensor and have appropriate size. See
+ * comments of NPUGeneratorImpl::get_state for information about the layout
+ * and size of the internal state.
+ */
+void NPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
+  static const size_t states_size = 200 * sizeof(4120); // this line is just here for BC reason
+  static const size_t seed_size = sizeof(uint64_t);
+  static const size_t offset_size = sizeof(int64_t);
+  static const size_t total_size = states_size + seed_size + offset_size;
+
+  at::detail::check_rng_state(new_state);
+
+  bool no_philox_seed = false;
+  auto new_state_size = new_state.numel();
+  if (new_state_size == total_size - offset_size) {
+    no_philox_seed = true;
+  } else {
+    TORCH_CHECK(new_state_size == total_size, "RNG state is wrong size");
+  }
+
+  uint64_t input_seed;
+  auto new_rng_state = new_state.data<uint8_t>();
+  memcpy(&input_seed, new_rng_state + states_size, seed_size);
+  this->set_current_seed(input_seed);
+  int64_t philox_offset = 0;
+  if (!no_philox_seed) {
+    memcpy(&philox_offset, new_rng_state + states_size + seed_size, offset_size);
+  }
+  this->set_philox_offset_per_thread(static_cast<uint64_t>(philox_offset));
+}
+
+/**
+ * Sets the philox_offset_per_thread_ to be used by curandStatePhilox4_32_10.
+ *
+ * See Note [Acquire lock when using random generators]
+ */
+void NPUGeneratorImpl::set_philox_offset_per_thread(uint64_t offset) {
+  // see Note [Why enforce RNG offset % 4 == 0?]
+  TORCH_CHECK(offset % 4 == 0, "offset must be a multiple of 4");
+  philox_offset_per_thread_ = offset;
+}
+
+/**
+ * Gets the current philox_offset_per_thread_ of NPUGeneratorImpl.
+ */
+uint64_t NPUGeneratorImpl::philox_offset_per_thread() const {
+  return philox_offset_per_thread_;
+}
+
+/**
+ * Called by NPUGraph to prepare this instance for a graph capture region.
+ * offset_extragraph is the initial offset at the start of the graphed region.
+ * offset_intragraph tracks the offset in the graphed region.
+ */
+void NPUGeneratorImpl::capture_prologue(int64_t* offset_extragraph) {
+  offset_extragraph_ = offset_extragraph;
+  offset_intragraph_ = 0;
+  graph_expects_this_gen_ = true;
+}
+
+/**
+ * Called by NPUGraph to finalize a graph capture region for this instance.
+ */
+uint64_t NPUGeneratorImpl::capture_epilogue() {
+  graph_expects_this_gen_ = false;
+  return offset_intragraph_;
+}
+
+/**
+ * Gets the seed and philox offset value to be used in
+ * curandStatePhilox4_32_10, in an opaque PhiloxNpuState that's safe
+ * and can be used non-divergently in callers whether NPU graph
+ * capture is underway or not. See
+ * Note [NPU Graph-safe RNG states]
+ *
+ * Each kernel using philox has to sensibly increment the offset
+ * for future users of philox. So it gets the "old" value for
+ * itself (before the add), and tells subsequent users which offset
+ * they should use, since only the kernel knows how many randoms
+ * it intends to generate.
+ *
+ * Increment should be at least the number of curand() random numbers used in
+ * each thread. It is the user's responsibility to make sure the increment
+ * for philox is never smaller than the number of curand() calls. An increment
+ * larger than the number of curand() calls is harmless, but anything smaller
+ * means random values from previous calls get reused.
+ *
+ * See Note [Acquire lock when using random generators]
+ */
+PhiloxNpuState NPUGeneratorImpl::philox_npu_state(uint64_t increment) {
+  // rounds increment up to the nearest multiple of 4
+  increment = ((increment + 3) / 4) * 4;
+  /*
+  if (at::npu::currentStreamCaptureStatus() != at::npu::CaptureStatus::None) {
+    TORCH_CHECK(graph_expects_this_gen_,
+                "philox_npu_state for an unexpected NPU generator used during capture. "
" + CAPTURE_DEFAULT_GENS_MSG); + // see Note [Why enforce RNG offset % 4 == 0?] + TORCH_INTERNAL_ASSERT(this->offset_intragraph_ % 4 == 0); + uint32_t offset = this->offset_intragraph_; + TORCH_INTERNAL_ASSERT(this->offset_intragraph_ <= + std::numeric_limits::max() - increment); + this->offset_intragraph_ += increment; + return PhiloxNpuState(this->seed_, + this->offset_extragraph_, + offset); + } else { + TORCH_CHECK(!graph_expects_this_gen_, + "NPU generator expects graph capture to be underway, " + "but the current stream is not capturing."); + // see Note [Why enforce RNG offset % 4 == 0?] + TORCH_INTERNAL_ASSERT(this->philox_offset_per_thread_ % 4 == 0); + uint64_t offset = this->philox_offset_per_thread_; + this->philox_offset_per_thread_ += increment; + return PhiloxNpuState(this->seed_, offset); + }*/ + + return PhiloxNpuState(this->seed_, 0); +} + +/** + * Temporarily accommodates call sites that use philox_engine_inputs. + * Allows incremental refactor of call sites to use philox_npu_state. + */ +std::pair NPUGeneratorImpl::philox_engine_inputs(uint64_t increment) { + // rounds increment up to the nearest multiple of 4 + increment = ((increment + 3) / 4) * 4; + // see Note [Why enforce RNG offset % 4 == 0?] + TORCH_INTERNAL_ASSERT(this->philox_offset_per_thread_ % 4 == 0); + uint64_t offset = this->philox_offset_per_thread_; + this->philox_offset_per_thread_ += increment; + return std::make_pair(this->seed_, offset); +} + +/* + * Gets the DeviceType of NPUGeneratorImpl. + * Used for type checking during run time. + */ +c10::DeviceType NPUGeneratorImpl::device_type() { + return c10::DeviceType::NPU; +} + +/** + * Public clone method implementation + * + * See Note [Acquire lock when using random generators] + */ +std::shared_ptr NPUGeneratorImpl::clone() const { + return std::shared_ptr(this->clone_impl()); +} + +/** + * Private clone method implementation + * + * See Note [Acquire lock when using random generators] + */ +NPUGeneratorImpl* NPUGeneratorImpl::clone_impl() const { + auto gen = new NPUGeneratorImpl(this->device().index()); + gen->set_current_seed(this->seed_); + gen->set_philox_offset_per_thread(this->philox_offset_per_thread_); + return gen; +} + +} // namespace at_npu diff --git a/torch_npu/csrc/aten/NPUGeneratorImpl.h b/torch_npu/csrc/aten/NPUGeneratorImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..569a9a3c202d8015cc71fce66072a448e5a00b9f --- /dev/null +++ b/torch_npu/csrc/aten/NPUGeneratorImpl.h @@ -0,0 +1,159 @@ +#pragma once + +#include +#include +#include +#include +#include + + +namespace at_npu { +/** + * Note [NPU Graph-safe RNG states] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * + * Strategy: + * ~~~~~~~~~ + * A NPU graph containing multiple RNG ops behaves like a + * single giant kernel from the perspective of ops external + * to the graph. During graph capture, logic below records + * the total of all offset increments that occur in the graphed + * region, and records the final total as the offset for the + * entire graph. + * + * When the graph reruns, the logic that reruns it + * increments this device's NPU generator's offset + * by that total. + * + * Meanwhile, within the graph, at capture time, instead of + * populating PhiloxNpuStates with the uint64_t offset pulled + * directly from the global state, PhiloNpuState instead + * holds a pointer to one-element stream-local int64_t device tensor + * holding an initial offset value, and a uint64_t holding an + * intra-graph offset. 
diff --git a/torch_npu/csrc/aten/NPUGeneratorImpl.h b/torch_npu/csrc/aten/NPUGeneratorImpl.h
new file mode 100644
index 0000000000000000000000000000000000000000..569a9a3c202d8015cc71fce66072a448e5a00b9f
--- /dev/null
+++ b/torch_npu/csrc/aten/NPUGeneratorImpl.h
@@ -0,0 +1,159 @@
+#pragma once
+
+#include <cstdint>
+#include <ATen/core/Generator.h>
+#include <c10/core/GeneratorImpl.h>
+#include <c10/core/TensorImpl.h>
+#include <c10/util/intrusive_ptr.h>
+
+
+namespace at_npu {
+/**
+ * Note [NPU Graph-safe RNG states]
+ * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ *
+ * Strategy:
+ * ~~~~~~~~~
+ * A NPU graph containing multiple RNG ops behaves like a
+ * single giant kernel from the perspective of ops external
+ * to the graph. During graph capture, logic below records
+ * the total of all offset increments that occur in the graphed
+ * region, and records the final total as the offset for the
+ * entire graph.
+ *
+ * When the graph reruns, the logic that reruns it
+ * increments this device's NPU generator's offset
+ * by that total.
+ *
+ * Meanwhile, within the graph, at capture time, instead of
+ * populating PhiloxNpuStates with the uint64_t offset pulled
+ * directly from the global state, PhiloxNpuState instead
+ * holds a pointer to a one-element stream-local int64_t device tensor
+ * holding an initial offset value, and a uint64_t holding an
+ * intra-graph offset. (The intra-graph offset starts from zero
+ * when capture begins.) In each consumer kernel,
+ * at::npu::philox::unpack computes the offset to use for this kernel
+ * as intra-graph offset + *initial offset.
+ *
+ * When the graph reruns, the logic that reruns it first
+ * fill_s the initial offset tensor with this device's
+ * NPU generator's current offset.
+ *
+ * The control flow above ensures graphed execution is bitwise
+ * identical to eager execution as long as RNG ops are enqueued
+ * from a single thread, even if RNG ops and graphs containing
+ * RNG ops are enqueued and run simultaneously on multiple streams.
+ *
+ * Usage:
+ * ~~~~~~
+ * PhiloxNpuState in this file, and unpack() in
+ * npu/NPUGraphsUtils.cuh allow non-divergent use of
+ * NPUGeneratorImpl whether graph capture is underway or not.
+ *
+ * Each PhiloxNpuState instance should be used for one and only one
+ * consumer kernel.
+ *
+ * Example (see e.g. native/npu/Dropout.cu):
+ *
+ * #include <ATen/npu/NPUGeneratorImpl.h>
+ * #include <ATen/npu/NPUGraphsUtils.cuh>
+ *
+ * __global__ void kernel(..., PhiloxNpuState philox_args) {
+ *   auto seeds = at::npu::philox::unpack(philox_args);
+ *   IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
+ *   curandStatePhilox4_32_10_t state;
+ *   curand_init(std::get<0>(seeds), // seed
+ *               idx,                // per-thread subsequence
+ *               std::get<1>(seeds), // offset in subsequence
+ *               &state);
+ *   ...
+ * }
+ *
+ * host_caller(...) {
+ *   PhiloxNpuState rng_engine_inputs;
+ *   {
+ *     // See Note [Acquire lock when using random generators]
+ *     std::lock_guard<std::mutex> lock(gen->mutex_);
+ *
+ *     // gen could be HostState or DevState here! No divergent code needed!
+ *     rng_engine_inputs = gen->philox_npu_state(offset_increment);
+ *   }
+ *   kernel<<<...>>>(..., rng_engine_inputs);
+ * }
+ *
+ */
+
+
+// Stores state values. Passed as a kernel argument. See "Usage:" above.
+struct PhiloxNpuState {
+  PhiloxNpuState() = default;
+  PhiloxNpuState(const PhiloxNpuState&) = default;
+  // Called if graph capture is not underway
+  PhiloxNpuState(uint64_t seed,
+                 uint64_t offset) {
+    seed_ = seed;
+    offset_.val = offset;
+  }
+  // Called if graph capture is underway
+  PhiloxNpuState(uint64_t seed,
+                 int64_t* offset_extragraph,
+                 uint32_t offset_intragraph) {
+    seed_ = seed;
+    offset_.ptr = offset_extragraph;
+    offset_intragraph_ = offset_intragraph;
+    captured_ = true;
+  }
+
+  // Public members, directly accessible by at::npu::philox::unpack.
+  // If we made them private with getters/setters, the getters/setters
+  // would have to be __device__, and we can't declare __device__ in ATen.
+  union Payload {
+    uint64_t val;
+    int64_t* ptr;
+  };
+
+  uint64_t seed_;
+  Payload offset_;
+  uint32_t offset_intragraph_;
+  bool captured_ = false;
+};
+
+struct C10_EXPORT NPUGeneratorImpl : public c10::GeneratorImpl {
+  // Constructors
+  NPUGeneratorImpl(c10::DeviceIndex device_index = -1);
+  ~NPUGeneratorImpl() = default;
+
+  // NPUGeneratorImpl methods
+  std::shared_ptr<NPUGeneratorImpl> clone() const;
+  void set_current_seed(uint64_t seed) override;
+  uint64_t current_seed() const override;
+  uint64_t seed() override;
+  void set_state(const c10::TensorImpl& new_state) override;
+  c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
+  void set_philox_offset_per_thread(uint64_t offset);
+  uint64_t philox_offset_per_thread() const;
+  void capture_prologue(int64_t* offset_extragraph);
+  uint64_t capture_epilogue();
+  PhiloxNpuState philox_npu_state(uint64_t increment);
+
+  // Temporarily accommodates call sites that use philox_engine_inputs.
+  // Allows incremental refactor of call sites to use philox_npu_state.
+  std::pair<uint64_t, uint64_t> philox_engine_inputs(uint64_t increment);
+  static c10::DeviceType device_type();
+
+private:
+  NPUGeneratorImpl* clone_impl() const override;
+  uint64_t seed_ = c10::default_rng_seed_val;
+  uint64_t philox_offset_per_thread_ = 0;
+  int64_t* offset_extragraph_;
+  uint32_t offset_intragraph_ = 0;
+  bool graph_expects_this_gen_ = false;
+};
+
+namespace detail {
+
+const at::Generator& getDefaultNPUGenerator(
+    c10::DeviceIndex device_index = -1);
+at::Generator createNPUGenerator(c10::DeviceIndex device_index = -1);
+
+} // namespace detail
+} // namespace at_npu
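
The note above leans on an unpack() helper (in npu/NPUGraphsUtils.cuh per the comment) that resolves a PhiloxNpuState to a concrete (seed, offset) pair in both the captured and non-captured cases. That helper is not part of this diff; a Python model of the behavior the note describes (the PhiloxNpuStateModel class is illustrative, not torch_npu API):

    from dataclasses import dataclass
    from typing import Optional
    import torch

    @dataclass
    class PhiloxNpuStateModel:  # stand-in for the C++ PhiloxNpuState
        seed: int
        offset: int = 0                                   # Payload.val
        offset_extragraph: Optional[torch.Tensor] = None  # Payload.ptr target
        offset_intragraph: int = 0
        captured: bool = False

    def unpack(s: PhiloxNpuStateModel):
        if s.captured:
            # At replay time the extragraph tensor holds the generator's
            # current offset; add this kernel's intra-graph offset.
            return s.seed, int(s.offset_extragraph.item()) + s.offset_intragraph
        return s.seed, s.offset
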
diff --git a/torch_npu/csrc/aten/ops/DropoutWithByteMaskKernelNpu.cpp b/torch_npu/csrc/aten/ops/DropoutWithByteMaskKernelNpu.cpp
index 7c0d7b42d300abb4e49cab1dc90d51e098e6bd49..bfddbe726db0b84c0e7a47fded457f5523fe6b83 100644
--- a/torch_npu/csrc/aten/ops/DropoutWithByteMaskKernelNpu.cpp
+++ b/torch_npu/csrc/aten/ops/DropoutWithByteMaskKernelNpu.cpp
@@ -22,7 +22,7 @@
 #include "torch_npu/csrc/core/npu/NPUCachingAllocator.h"
 #include "torch_npu/csrc/aten/NPUNativeFunctions.h"
-#include
+#include "torch_npu/csrc/aten/NPUGeneratorImpl.h"
 
 namespace at_npu {
 namespace native {
@@ -57,7 +57,7 @@ at::Tensor dropout_gen_byte_mask(const at::Tensor& self, at::Scalar prob) {
   // 127~64 63~0
   // so, we set seed2 = 0 to ensure the seed which the user set is equal to the seed
   // used by the operator DropOutGenMaskV3
-  const auto gen = at::npu::detail::getDefaultNPUGenerator();
+  const auto gen = at_npu::detail::getDefaultNPUGenerator();
   const int64_t seed = static_cast<int64_t>(gen.current_seed());
   const int64_t seed2 = 0;
   cmd.Name("DropOutGenMaskV3")
diff --git a/torch_npu/csrc/npu/Event.cpp b/torch_npu/csrc/npu/Event.cpp
index d93870ce067cd681e472d6e71ae63907c3d44722..36489846cda28970174e3b6a834bac25a3a229f1 100644
--- a/torch_npu/csrc/npu/Event.cpp
+++ b/torch_npu/csrc/npu/Event.cpp
@@ -128,7 +128,7 @@ static PyMethodDef THNPEvent_methods[] = {
 
 PyTypeObject THNPEventType = {
   PyVarObject_HEAD_INIT(nullptr, 0)
-  "torch._C._NPUEventBase",                    /* tp_name */
+  "torch_npu._C._NPUEventBase",                /* tp_name */
   sizeof(THNPEvent),                           /* tp_basicsize */
   0,                                           /* tp_itemsize */
   (destructor)THNPEvent_dealloc,               /* tp_dealloc */
diff --git a/torch_npu/csrc/npu/Module.cpp b/torch_npu/csrc/npu/Module.cpp
index cd344dde789892ed2643fd5b23e1f5ec902fc725..6c954bd2bf6e803db4bf3c033312f054bce1d846 100644
--- a/torch_npu/csrc/npu/Module.cpp
+++ b/torch_npu/csrc/npu/Module.cpp
@@ -14,7 +14,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include
 #include
 #include
 
@@ -44,6 +43,8 @@
 #include "torch_npu/csrc/profiler/e2e_profiler.h"
 #include "torch_npu/csrc/framework/graph/execute/GraphExecutor.h"
 #include "torch_npu/csrc/core/npu/NPURunMode.h"
+#include "torch_npu/csrc/aten/NPUGeneratorImpl.h"
+
 
 static PyObject* THNPModule_initExtension(PyObject* self, PyObject* noargs) {
   HANDLE_TH_ERRORS
@@ -68,7 +69,7 @@ static PyObject* THNPModule_initExtension(PyObject* self, PyObject* noargs) {
   auto num_npus = c10::npu::device_count();
   auto default_npu_generators = PyTuple_New(static_cast<Py_ssize_t>(num_npus));
   for (int i = 0; i < num_npus; i++) {
-    auto gen = at::npu::detail::getDefaultNPUGenerator(i);
+    auto gen = at_npu::detail::getDefaultNPUGenerator(i);
     auto cast_gen = (THPGenerator*)THPGenerator_initDefaultGenerator(gen);
     // This reference is meant to be given away, so no need to incref here.
     PyTuple_SetItem(default_npu_generators, i, (PyObject*)cast_gen);
diff --git a/torch_npu/csrc/npu/Stream.cpp b/torch_npu/csrc/npu/Stream.cpp
index 04eabf1925d766116089521e0df589f17fde49d3..852e740e98eac524d0083da3ae632d5145adec56 100644
--- a/torch_npu/csrc/npu/Stream.cpp
+++ b/torch_npu/csrc/npu/Stream.cpp
@@ -134,7 +134,7 @@ static PyMethodDef THNPStream_methods[] = {
 
 PyTypeObject THNPStreamType = {
   PyVarObject_HEAD_INIT(nullptr, 0)
-  "torch._C._NPUStreamBase",                   /* tp_name */
+  "torch_npu._C._NPUStreamBase",               /* tp_name */
   sizeof(THNPStream),                          /* tp_basicsize */
   0,                                           /* tp_itemsize */
   (destructor)THNPStream_dealloc,              /* tp_dealloc */
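
With the tp_name strings now naming the module that actually owns these types, the Python-visible metadata lines up, since static extension types derive __module__ and __name__ from tp_name. A quick check, assuming torch_npu is importable:

    import torch_npu

    assert torch_npu._C.Generator.__module__ == "torch_npu._C"
    assert torch_npu._C.Generator.__name__ == "Generator"
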