diff --git a/data/datasets/README.md b/data/datasets/README.md index f40c2e824e94c038e186d9e1ffa149a8382a41e2..5995236cb987425db1c135cfb847f6b766372618 100644 --- a/data/datasets/README.md +++ b/data/datasets/README.md @@ -1 +1 @@ -# This is the default datasets location required by inference models +# This is the default datasets location required by inference models diff --git a/models/audio/speech_recognition/conformer/igie/requirements.txt b/models/audio/speech_recognition/conformer/igie/requirements.txt index 58a9d085c1edb8f5304ba5cd9349eea3d83537d6..2f7cd1f24262857100607eb19f6ccc14b7e98a31 100644 --- a/models/audio/speech_recognition/conformer/igie/requirements.txt +++ b/models/audio/speech_recognition/conformer/igie/requirements.txt @@ -1,4 +1,4 @@ -tqdm -onnx -typeguard==2.13.3 -onnxsim +tqdm +onnx +typeguard==2.13.3 +onnxsim diff --git a/models/audio/speech_recognition/conformer/igie/wenet/docs/make.bat b/models/audio/speech_recognition/conformer/igie/wenet/docs/make.bat index 16b063834966b9f8b0cb6bbc4f46ec4e74fa52ba..a42274a63310b8672adb4eb1bbd2c170cdc7684a 100644 --- a/models/audio/speech_recognition/conformer/igie/wenet/docs/make.bat +++ b/models/audio/speech_recognition/conformer/igie/wenet/docs/make.bat @@ -1,35 +1,35 @@ -@ECHO OFF - -pushd %~dp0 - -REM Command file for Sphinx documentation - -if "%SPHINXBUILD%" == "" ( - set SPHINXBUILD=sphinx-build -) -set SOURCEDIR=. -set BUILDDIR=_build - -if "%1" == "" goto help - -%SPHINXBUILD% >NUL 2>NUL -if errorlevel 9009 ( - echo. - echo.The 'sphinx-build' command was not found. Make sure you have Sphinx - echo.installed, then set the SPHINXBUILD environment variable to point - echo.to the full path of the 'sphinx-build' executable. Alternatively you - echo.may add the Sphinx directory to PATH. - echo. - echo.If you don't have Sphinx installed, grab it from - echo.http://sphinx-doc.org/ - exit /b 1 -) - -%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% -goto end - -:help -%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% - -:end -popd +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=. +set BUILDDIR=_build + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. 
+ echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/models/audio/speech_recognition/conformer/igie/wenet/runtime/android/app/src/main/cpp/patch/openfst/src/CMakeLists.txt b/models/audio/speech_recognition/conformer/igie/wenet/runtime/android/app/src/main/cpp/patch/openfst/src/CMakeLists.txt index 04fb1c15b6c1098b61f57f3c276d1a7595b20319..04051ef5ae46c04a40c1ffccc98c37fa594ad13e 100644 --- a/models/audio/speech_recognition/conformer/igie/wenet/runtime/android/app/src/main/cpp/patch/openfst/src/CMakeLists.txt +++ b/models/audio/speech_recognition/conformer/igie/wenet/runtime/android/app/src/main/cpp/patch/openfst/src/CMakeLists.txt @@ -1,23 +1,23 @@ - -#-DHAVE_CONFIG_H -I./../include -fno-exceptions -funsigned-char -std=c++11 -MT symbol-table.lo -MD -MP -MF .deps/symbol-table.Tpo -c symbol-table.cc -fno-common -DPIC -o .libs/symbol-table.o - -include_directories(./include/) -install(DIRECTORY include/ DESTINATION include/ - FILES_MATCHING PATTERN "*.h") - -add_subdirectory(lib) - -if(HAVE_SCRIPT) - add_subdirectory(script) -endif(HAVE_SCRIPT) - -if(HAVE_BIN) - add_subdirectory(bin) -endif(HAVE_BIN) - -add_subdirectory(extensions) - -if(BUILD_TESTING) - enable_testing() - add_subdirectory(test) -endif(BUILD_TESTING) + +#-DHAVE_CONFIG_H -I./../include -fno-exceptions -funsigned-char -std=c++11 -MT symbol-table.lo -MD -MP -MF .deps/symbol-table.Tpo -c symbol-table.cc -fno-common -DPIC -o .libs/symbol-table.o + +include_directories(./include/) +install(DIRECTORY include/ DESTINATION include/ + FILES_MATCHING PATTERN "*.h") + +add_subdirectory(lib) + +if(HAVE_SCRIPT) + add_subdirectory(script) +endif(HAVE_SCRIPT) + +if(HAVE_BIN) + add_subdirectory(bin) +endif(HAVE_BIN) + +add_subdirectory(extensions) + +if(BUILD_TESTING) + enable_testing() + add_subdirectory(test) +endif(BUILD_TESTING) diff --git a/models/audio/speech_recognition/conformer/igie/wenet/runtime/android/app/src/main/cpp/utils/thread_pool.h b/models/audio/speech_recognition/conformer/igie/wenet/runtime/android/app/src/main/cpp/utils/thread_pool.h index 03cc45b8b2e4bc273706353f248ef088db0eaca8..a78162995d90bf079ad091cf14cb9f2cd4476d05 100644 --- a/models/audio/speech_recognition/conformer/igie/wenet/runtime/android/app/src/main/cpp/utils/thread_pool.h +++ b/models/audio/speech_recognition/conformer/igie/wenet/runtime/android/app/src/main/cpp/utils/thread_pool.h @@ -1,113 +1,113 @@ -// Copyright (c) 2012 Jakob Progsch, Václav Zeman - -// This software is provided 'as-is', without any express or implied -// warranty. In no event will the authors be held liable for any damages -// arising from the use of this software. - -// Permission is granted to anyone to use this software for any purpose, -// including commercial applications, and to alter it and redistribute it -// freely, subject to the following restrictions: - -// 1. The origin of this software must not be misrepresented; you must not -// claim that you wrote the original software. If you use this software -// in a product, an acknowledgment in the product documentation would be -// appreciated but is not required. - -// 2. Altered source versions must be plainly marked as such, and must not be -// misrepresented as being the original software. - -// 3. This notice may not be removed or altered from any source -// distribution. 
- -#ifndef UTILS_THREAD_POOL_H_ -#define UTILS_THREAD_POOL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -class ThreadPool { - public: - explicit ThreadPool(size_t); - template - auto enqueue(F&& f, Args&&... args) - -> std::future::type>; - ~ThreadPool(); - - private: - // need to keep track of threads so we can join them - std::vector workers; - // the task queue - std::queue > tasks; - - // synchronization - std::mutex queue_mutex; - std::condition_variable condition; - bool stop; -}; - -// the constructor just launches some amount of workers -inline ThreadPool::ThreadPool(size_t threads) : stop(false) { - for (size_t i = 0; i < threads; ++i) - workers.emplace_back([this] { - for (;;) { - std::function task; - - { - std::unique_lock lock(this->queue_mutex); - this->condition.wait( - lock, [this] { return this->stop || !this->tasks.empty(); }); - if (this->stop && this->tasks.empty()) return; - task = std::move(this->tasks.front()); - this->tasks.pop(); - } - - task(); - } - }); -} - -// add new work item to the pool -template -auto ThreadPool::enqueue(F&& f, Args&&... args) - -> std::future::type> { - using return_type = typename std::result_of::type; - - auto task = std::make_shared >( - std::bind(std::forward(f), std::forward(args)...)); - - std::future res = task->get_future(); - { - std::unique_lock lock(queue_mutex); - - // don't allow enqueueing after stopping the pool - if (stop) { - throw std::runtime_error("enqueue on stopped ThreadPool"); - } - - tasks.emplace([task]() { (*task)(); }); - } - condition.notify_one(); - return res; -} - -// the destructor joins all threads -inline ThreadPool::~ThreadPool() { - { - std::unique_lock lock(queue_mutex); - stop = true; - } - condition.notify_all(); - for (std::thread& worker : workers) { - worker.join(); - } -} - -#endif // UTILS_THREAD_POOL_H_ +// Copyright (c) 2012 Jakob Progsch, Václav Zeman + +// This software is provided 'as-is', without any express or implied +// warranty. In no event will the authors be held liable for any damages +// arising from the use of this software. + +// Permission is granted to anyone to use this software for any purpose, +// including commercial applications, and to alter it and redistribute it +// freely, subject to the following restrictions: + +// 1. The origin of this software must not be misrepresented; you must not +// claim that you wrote the original software. If you use this software +// in a product, an acknowledgment in the product documentation would be +// appreciated but is not required. + +// 2. Altered source versions must be plainly marked as such, and must not be +// misrepresented as being the original software. + +// 3. This notice may not be removed or altered from any source +// distribution. + +#ifndef UTILS_THREAD_POOL_H_ +#define UTILS_THREAD_POOL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +class ThreadPool { + public: + explicit ThreadPool(size_t); + template + auto enqueue(F&& f, Args&&... 
args) + -> std::future::type>; + ~ThreadPool(); + + private: + // need to keep track of threads so we can join them + std::vector workers; + // the task queue + std::queue > tasks; + + // synchronization + std::mutex queue_mutex; + std::condition_variable condition; + bool stop; +}; + +// the constructor just launches some amount of workers +inline ThreadPool::ThreadPool(size_t threads) : stop(false) { + for (size_t i = 0; i < threads; ++i) + workers.emplace_back([this] { + for (;;) { + std::function task; + + { + std::unique_lock lock(this->queue_mutex); + this->condition.wait( + lock, [this] { return this->stop || !this->tasks.empty(); }); + if (this->stop && this->tasks.empty()) return; + task = std::move(this->tasks.front()); + this->tasks.pop(); + } + + task(); + } + }); +} + +// add new work item to the pool +template +auto ThreadPool::enqueue(F&& f, Args&&... args) + -> std::future::type> { + using return_type = typename std::result_of::type; + + auto task = std::make_shared >( + std::bind(std::forward(f), std::forward(args)...)); + + std::future res = task->get_future(); + { + std::unique_lock lock(queue_mutex); + + // don't allow enqueueing after stopping the pool + if (stop) { + throw std::runtime_error("enqueue on stopped ThreadPool"); + } + + tasks.emplace([task]() { (*task)(); }); + } + condition.notify_one(); + return res; +} + +// the destructor joins all threads +inline ThreadPool::~ThreadPool() { + { + std::unique_lock lock(queue_mutex); + stop = true; + } + condition.notify_all(); + for (std::thread& worker : workers) { + worker.join(); + } +} + +#endif // UTILS_THREAD_POOL_H_ diff --git a/models/audio/speech_recognition/conformer/igie/wenet/runtime/android/gradlew.bat b/models/audio/speech_recognition/conformer/igie/wenet/runtime/android/gradlew.bat index e95643d6a2ca62258464e83c72f5156dc941c609..f9553162f122c71b34635112e717c3e733b5b212 100644 --- a/models/audio/speech_recognition/conformer/igie/wenet/runtime/android/gradlew.bat +++ b/models/audio/speech_recognition/conformer/igie/wenet/runtime/android/gradlew.bat @@ -1,84 +1,84 @@ -@if "%DEBUG%" == "" @echo off -@rem ########################################################################## -@rem -@rem Gradle startup script for Windows -@rem -@rem ########################################################################## - -@rem Set local scope for the variables with windows NT shell -if "%OS%"=="Windows_NT" setlocal - -set DIRNAME=%~dp0 -if "%DIRNAME%" == "" set DIRNAME=. -set APP_BASE_NAME=%~n0 -set APP_HOME=%DIRNAME% - -@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -set DEFAULT_JVM_OPTS= - -@rem Find java.exe -if defined JAVA_HOME goto findJavaFromJavaHome - -set JAVA_EXE=java.exe -%JAVA_EXE% -version >NUL 2>&1 -if "%ERRORLEVEL%" == "0" goto init - -echo. -echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. - -goto fail - -:findJavaFromJavaHome -set JAVA_HOME=%JAVA_HOME:"=% -set JAVA_EXE=%JAVA_HOME%/bin/java.exe - -if exist "%JAVA_EXE%" goto init - -echo. -echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. 
- -goto fail - -:init -@rem Get command-line arguments, handling Windows variants - -if not "%OS%" == "Windows_NT" goto win9xME_args - -:win9xME_args -@rem Slurp the command line arguments. -set CMD_LINE_ARGS= -set _SKIP=2 - -:win9xME_args_slurp -if "x%~1" == "x" goto execute - -set CMD_LINE_ARGS=%* - -:execute -@rem Setup the command line - -set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar - -@rem Execute Gradle -"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% - -:end -@rem End local scope for the variables with windows NT shell -if "%ERRORLEVEL%"=="0" goto mainEnd - -:fail -rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of -rem the _cmd.exe /c_ return code! -if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 -exit /b 1 - -:mainEnd -if "%OS%"=="Windows_NT" endlocal - -:omega +@if "%DEBUG%" == "" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%" == "" set DIRNAME=. +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS= + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if "%ERRORLEVEL%" == "0" goto init + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto init + +echo. +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:init +@rem Get command-line arguments, handling Windows variants + +if not "%OS%" == "Windows_NT" goto win9xME_args + +:win9xME_args +@rem Slurp the command line arguments. +set CMD_LINE_ARGS= +set _SKIP=2 + +:win9xME_args_slurp +if "x%~1" == "x" goto execute + +set CMD_LINE_ARGS=%* + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% + +:end +@rem End local scope for the variables with windows NT shell +if "%ERRORLEVEL%"=="0" goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! 
+if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 +exit /b 1 + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/models/audio/speech_recognition/conformer/igie/wenet/runtime/binding/python/patch/openfst/src/CMakeLists.txt b/models/audio/speech_recognition/conformer/igie/wenet/runtime/binding/python/patch/openfst/src/CMakeLists.txt index 04fb1c15b6c1098b61f57f3c276d1a7595b20319..04051ef5ae46c04a40c1ffccc98c37fa594ad13e 100644 --- a/models/audio/speech_recognition/conformer/igie/wenet/runtime/binding/python/patch/openfst/src/CMakeLists.txt +++ b/models/audio/speech_recognition/conformer/igie/wenet/runtime/binding/python/patch/openfst/src/CMakeLists.txt @@ -1,23 +1,23 @@ - -#-DHAVE_CONFIG_H -I./../include -fno-exceptions -funsigned-char -std=c++11 -MT symbol-table.lo -MD -MP -MF .deps/symbol-table.Tpo -c symbol-table.cc -fno-common -DPIC -o .libs/symbol-table.o - -include_directories(./include/) -install(DIRECTORY include/ DESTINATION include/ - FILES_MATCHING PATTERN "*.h") - -add_subdirectory(lib) - -if(HAVE_SCRIPT) - add_subdirectory(script) -endif(HAVE_SCRIPT) - -if(HAVE_BIN) - add_subdirectory(bin) -endif(HAVE_BIN) - -add_subdirectory(extensions) - -if(BUILD_TESTING) - enable_testing() - add_subdirectory(test) -endif(BUILD_TESTING) + +#-DHAVE_CONFIG_H -I./../include -fno-exceptions -funsigned-char -std=c++11 -MT symbol-table.lo -MD -MP -MF .deps/symbol-table.Tpo -c symbol-table.cc -fno-common -DPIC -o .libs/symbol-table.o + +include_directories(./include/) +install(DIRECTORY include/ DESTINATION include/ + FILES_MATCHING PATTERN "*.h") + +add_subdirectory(lib) + +if(HAVE_SCRIPT) + add_subdirectory(script) +endif(HAVE_SCRIPT) + +if(HAVE_BIN) + add_subdirectory(bin) +endif(HAVE_BIN) + +add_subdirectory(extensions) + +if(BUILD_TESTING) + enable_testing() + add_subdirectory(test) +endif(BUILD_TESTING) diff --git a/models/audio/speech_recognition/conformer/igie/wenet/runtime/binding/python/utils/thread_pool.h b/models/audio/speech_recognition/conformer/igie/wenet/runtime/binding/python/utils/thread_pool.h index 03cc45b8b2e4bc273706353f248ef088db0eaca8..a78162995d90bf079ad091cf14cb9f2cd4476d05 100644 --- a/models/audio/speech_recognition/conformer/igie/wenet/runtime/binding/python/utils/thread_pool.h +++ b/models/audio/speech_recognition/conformer/igie/wenet/runtime/binding/python/utils/thread_pool.h @@ -1,113 +1,113 @@ -// Copyright (c) 2012 Jakob Progsch, Václav Zeman - -// This software is provided 'as-is', without any express or implied -// warranty. In no event will the authors be held liable for any damages -// arising from the use of this software. - -// Permission is granted to anyone to use this software for any purpose, -// including commercial applications, and to alter it and redistribute it -// freely, subject to the following restrictions: - -// 1. The origin of this software must not be misrepresented; you must not -// claim that you wrote the original software. If you use this software -// in a product, an acknowledgment in the product documentation would be -// appreciated but is not required. - -// 2. Altered source versions must be plainly marked as such, and must not be -// misrepresented as being the original software. - -// 3. This notice may not be removed or altered from any source -// distribution. 
- -#ifndef UTILS_THREAD_POOL_H_ -#define UTILS_THREAD_POOL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -class ThreadPool { - public: - explicit ThreadPool(size_t); - template - auto enqueue(F&& f, Args&&... args) - -> std::future::type>; - ~ThreadPool(); - - private: - // need to keep track of threads so we can join them - std::vector workers; - // the task queue - std::queue > tasks; - - // synchronization - std::mutex queue_mutex; - std::condition_variable condition; - bool stop; -}; - -// the constructor just launches some amount of workers -inline ThreadPool::ThreadPool(size_t threads) : stop(false) { - for (size_t i = 0; i < threads; ++i) - workers.emplace_back([this] { - for (;;) { - std::function task; - - { - std::unique_lock lock(this->queue_mutex); - this->condition.wait( - lock, [this] { return this->stop || !this->tasks.empty(); }); - if (this->stop && this->tasks.empty()) return; - task = std::move(this->tasks.front()); - this->tasks.pop(); - } - - task(); - } - }); -} - -// add new work item to the pool -template -auto ThreadPool::enqueue(F&& f, Args&&... args) - -> std::future::type> { - using return_type = typename std::result_of::type; - - auto task = std::make_shared >( - std::bind(std::forward(f), std::forward(args)...)); - - std::future res = task->get_future(); - { - std::unique_lock lock(queue_mutex); - - // don't allow enqueueing after stopping the pool - if (stop) { - throw std::runtime_error("enqueue on stopped ThreadPool"); - } - - tasks.emplace([task]() { (*task)(); }); - } - condition.notify_one(); - return res; -} - -// the destructor joins all threads -inline ThreadPool::~ThreadPool() { - { - std::unique_lock lock(queue_mutex); - stop = true; - } - condition.notify_all(); - for (std::thread& worker : workers) { - worker.join(); - } -} - -#endif // UTILS_THREAD_POOL_H_ +// Copyright (c) 2012 Jakob Progsch, Václav Zeman + +// This software is provided 'as-is', without any express or implied +// warranty. In no event will the authors be held liable for any damages +// arising from the use of this software. + +// Permission is granted to anyone to use this software for any purpose, +// including commercial applications, and to alter it and redistribute it +// freely, subject to the following restrictions: + +// 1. The origin of this software must not be misrepresented; you must not +// claim that you wrote the original software. If you use this software +// in a product, an acknowledgment in the product documentation would be +// appreciated but is not required. + +// 2. Altered source versions must be plainly marked as such, and must not be +// misrepresented as being the original software. + +// 3. This notice may not be removed or altered from any source +// distribution. + +#ifndef UTILS_THREAD_POOL_H_ +#define UTILS_THREAD_POOL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +class ThreadPool { + public: + explicit ThreadPool(size_t); + template + auto enqueue(F&& f, Args&&... 
args) + -> std::future::type>; + ~ThreadPool(); + + private: + // need to keep track of threads so we can join them + std::vector workers; + // the task queue + std::queue > tasks; + + // synchronization + std::mutex queue_mutex; + std::condition_variable condition; + bool stop; +}; + +// the constructor just launches some amount of workers +inline ThreadPool::ThreadPool(size_t threads) : stop(false) { + for (size_t i = 0; i < threads; ++i) + workers.emplace_back([this] { + for (;;) { + std::function task; + + { + std::unique_lock lock(this->queue_mutex); + this->condition.wait( + lock, [this] { return this->stop || !this->tasks.empty(); }); + if (this->stop && this->tasks.empty()) return; + task = std::move(this->tasks.front()); + this->tasks.pop(); + } + + task(); + } + }); +} + +// add new work item to the pool +template +auto ThreadPool::enqueue(F&& f, Args&&... args) + -> std::future::type> { + using return_type = typename std::result_of::type; + + auto task = std::make_shared >( + std::bind(std::forward(f), std::forward(args)...)); + + std::future res = task->get_future(); + { + std::unique_lock lock(queue_mutex); + + // don't allow enqueueing after stopping the pool + if (stop) { + throw std::runtime_error("enqueue on stopped ThreadPool"); + } + + tasks.emplace([task]() { (*task)(); }); + } + condition.notify_one(); + return res; +} + +// the destructor joins all threads +inline ThreadPool::~ThreadPool() { + { + std::unique_lock lock(queue_mutex); + stop = true; + } + condition.notify_all(); + for (std::thread& worker : workers) { + worker.join(); + } +} + +#endif // UTILS_THREAD_POOL_H_ diff --git a/models/audio/speech_recognition/conformer/igie/wenet/runtime/core/patch/openfst/src/CMakeLists.txt b/models/audio/speech_recognition/conformer/igie/wenet/runtime/core/patch/openfst/src/CMakeLists.txt index 04fb1c15b6c1098b61f57f3c276d1a7595b20319..04051ef5ae46c04a40c1ffccc98c37fa594ad13e 100644 --- a/models/audio/speech_recognition/conformer/igie/wenet/runtime/core/patch/openfst/src/CMakeLists.txt +++ b/models/audio/speech_recognition/conformer/igie/wenet/runtime/core/patch/openfst/src/CMakeLists.txt @@ -1,23 +1,23 @@ - -#-DHAVE_CONFIG_H -I./../include -fno-exceptions -funsigned-char -std=c++11 -MT symbol-table.lo -MD -MP -MF .deps/symbol-table.Tpo -c symbol-table.cc -fno-common -DPIC -o .libs/symbol-table.o - -include_directories(./include/) -install(DIRECTORY include/ DESTINATION include/ - FILES_MATCHING PATTERN "*.h") - -add_subdirectory(lib) - -if(HAVE_SCRIPT) - add_subdirectory(script) -endif(HAVE_SCRIPT) - -if(HAVE_BIN) - add_subdirectory(bin) -endif(HAVE_BIN) - -add_subdirectory(extensions) - -if(BUILD_TESTING) - enable_testing() - add_subdirectory(test) -endif(BUILD_TESTING) + +#-DHAVE_CONFIG_H -I./../include -fno-exceptions -funsigned-char -std=c++11 -MT symbol-table.lo -MD -MP -MF .deps/symbol-table.Tpo -c symbol-table.cc -fno-common -DPIC -o .libs/symbol-table.o + +include_directories(./include/) +install(DIRECTORY include/ DESTINATION include/ + FILES_MATCHING PATTERN "*.h") + +add_subdirectory(lib) + +if(HAVE_SCRIPT) + add_subdirectory(script) +endif(HAVE_SCRIPT) + +if(HAVE_BIN) + add_subdirectory(bin) +endif(HAVE_BIN) + +add_subdirectory(extensions) + +if(BUILD_TESTING) + enable_testing() + add_subdirectory(test) +endif(BUILD_TESTING) diff --git a/models/audio/speech_recognition/conformer/igie/wenet/runtime/core/utils/thread_pool.h b/models/audio/speech_recognition/conformer/igie/wenet/runtime/core/utils/thread_pool.h index 
03cc45b8b2e4bc273706353f248ef088db0eaca8..a78162995d90bf079ad091cf14cb9f2cd4476d05 100644 --- a/models/audio/speech_recognition/conformer/igie/wenet/runtime/core/utils/thread_pool.h +++ b/models/audio/speech_recognition/conformer/igie/wenet/runtime/core/utils/thread_pool.h @@ -1,113 +1,113 @@ -// Copyright (c) 2012 Jakob Progsch, Václav Zeman - -// This software is provided 'as-is', without any express or implied -// warranty. In no event will the authors be held liable for any damages -// arising from the use of this software. - -// Permission is granted to anyone to use this software for any purpose, -// including commercial applications, and to alter it and redistribute it -// freely, subject to the following restrictions: - -// 1. The origin of this software must not be misrepresented; you must not -// claim that you wrote the original software. If you use this software -// in a product, an acknowledgment in the product documentation would be -// appreciated but is not required. - -// 2. Altered source versions must be plainly marked as such, and must not be -// misrepresented as being the original software. - -// 3. This notice may not be removed or altered from any source -// distribution. - -#ifndef UTILS_THREAD_POOL_H_ -#define UTILS_THREAD_POOL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -class ThreadPool { - public: - explicit ThreadPool(size_t); - template - auto enqueue(F&& f, Args&&... args) - -> std::future::type>; - ~ThreadPool(); - - private: - // need to keep track of threads so we can join them - std::vector workers; - // the task queue - std::queue > tasks; - - // synchronization - std::mutex queue_mutex; - std::condition_variable condition; - bool stop; -}; - -// the constructor just launches some amount of workers -inline ThreadPool::ThreadPool(size_t threads) : stop(false) { - for (size_t i = 0; i < threads; ++i) - workers.emplace_back([this] { - for (;;) { - std::function task; - - { - std::unique_lock lock(this->queue_mutex); - this->condition.wait( - lock, [this] { return this->stop || !this->tasks.empty(); }); - if (this->stop && this->tasks.empty()) return; - task = std::move(this->tasks.front()); - this->tasks.pop(); - } - - task(); - } - }); -} - -// add new work item to the pool -template -auto ThreadPool::enqueue(F&& f, Args&&... args) - -> std::future::type> { - using return_type = typename std::result_of::type; - - auto task = std::make_shared >( - std::bind(std::forward(f), std::forward(args)...)); - - std::future res = task->get_future(); - { - std::unique_lock lock(queue_mutex); - - // don't allow enqueueing after stopping the pool - if (stop) { - throw std::runtime_error("enqueue on stopped ThreadPool"); - } - - tasks.emplace([task]() { (*task)(); }); - } - condition.notify_one(); - return res; -} - -// the destructor joins all threads -inline ThreadPool::~ThreadPool() { - { - std::unique_lock lock(queue_mutex); - stop = true; - } - condition.notify_all(); - for (std::thread& worker : workers) { - worker.join(); - } -} - -#endif // UTILS_THREAD_POOL_H_ +// Copyright (c) 2012 Jakob Progsch, Václav Zeman + +// This software is provided 'as-is', without any express or implied +// warranty. In no event will the authors be held liable for any damages +// arising from the use of this software. 
+ +// Permission is granted to anyone to use this software for any purpose, +// including commercial applications, and to alter it and redistribute it +// freely, subject to the following restrictions: + +// 1. The origin of this software must not be misrepresented; you must not +// claim that you wrote the original software. If you use this software +// in a product, an acknowledgment in the product documentation would be +// appreciated but is not required. + +// 2. Altered source versions must be plainly marked as such, and must not be +// misrepresented as being the original software. + +// 3. This notice may not be removed or altered from any source +// distribution. + +#ifndef UTILS_THREAD_POOL_H_ +#define UTILS_THREAD_POOL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +class ThreadPool { + public: + explicit ThreadPool(size_t); + template + auto enqueue(F&& f, Args&&... args) + -> std::future::type>; + ~ThreadPool(); + + private: + // need to keep track of threads so we can join them + std::vector workers; + // the task queue + std::queue > tasks; + + // synchronization + std::mutex queue_mutex; + std::condition_variable condition; + bool stop; +}; + +// the constructor just launches some amount of workers +inline ThreadPool::ThreadPool(size_t threads) : stop(false) { + for (size_t i = 0; i < threads; ++i) + workers.emplace_back([this] { + for (;;) { + std::function task; + + { + std::unique_lock lock(this->queue_mutex); + this->condition.wait( + lock, [this] { return this->stop || !this->tasks.empty(); }); + if (this->stop && this->tasks.empty()) return; + task = std::move(this->tasks.front()); + this->tasks.pop(); + } + + task(); + } + }); +} + +// add new work item to the pool +template +auto ThreadPool::enqueue(F&& f, Args&&... args) + -> std::future::type> { + using return_type = typename std::result_of::type; + + auto task = std::make_shared >( + std::bind(std::forward(f), std::forward(args)...)); + + std::future res = task->get_future(); + { + std::unique_lock lock(queue_mutex); + + // don't allow enqueueing after stopping the pool + if (stop) { + throw std::runtime_error("enqueue on stopped ThreadPool"); + } + + tasks.emplace([task]() { (*task)(); }); + } + condition.notify_one(); + return res; +} + +// the destructor joins all threads +inline ThreadPool::~ThreadPool() { + { + std::unique_lock lock(queue_mutex); + stop = true; + } + condition.notify_all(); + for (std::thread& worker : workers) { + worker.join(); + } +} + +#endif // UTILS_THREAD_POOL_H_ diff --git a/models/audio/speech_recognition/conformer/igie/wenet/runtime/gpu/tensorrt/LayerNormPlugin/LayerNormPlugin.h b/models/audio/speech_recognition/conformer/igie/wenet/runtime/gpu/tensorrt/LayerNormPlugin/LayerNormPlugin.h index 267666cf4909198aac5d1c09f60b896629b3b788..176bf0ea72bc138c6243969abcc658f676a9a37c 100644 --- a/models/audio/speech_recognition/conformer/igie/wenet/runtime/gpu/tensorrt/LayerNormPlugin/LayerNormPlugin.h +++ b/models/audio/speech_recognition/conformer/igie/wenet/runtime/gpu/tensorrt/LayerNormPlugin/LayerNormPlugin.h @@ -1,230 +1,230 @@ -// Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#ifndef RUNTIME_GPU_TENSORRT_LAYERNORMPLUGIN_LAYERNORMPLUGIN_H_ -#define RUNTIME_GPU_TENSORRT_LAYERNORMPLUGIN_LAYERNORMPLUGIN_H_ - -#include -#include -#include -#include -#include // NOLINT -#include // NOLINT - -#define CEIL_DIVIDE(X, Y) (((X)+(Y)-1)/(Y)) -#define CEIL_TO(X, Y) (((X)+(Y)-1)/(Y)*(Y)) - -template -__device__ T epsilon(); - -template <> -__device__ float epsilon() { - return (float)6.0e-12; // NOLINT -} - -template <> -__device__ half epsilon() { - return (half)6.0e-6; -} - -// +------- Debug wrapper ----------------------------------- -#if DEBUG -#define WHERE_AM_I() do {printf("[%s]:this=->%p\n", __func__, this);} while (0); -#else -#define WHERE_AM_I() -#endif // DEBUG - -// +------- Plguin ------------------------------------------- -namespace { // NOLINT -static const char* PLUGIN_NAME{"LayerNorm"}; -static const char* PLUGIN_VERSION{"1"}; -} // namespace - -namespace nvinfer1 { - -// +------- Plugin body --------------------------------------- -class LayerNormPlugin: public IPluginV2DynamicExt { - private: - std::string name_; - std::string namespace_; - - public: - LayerNormPlugin(const std::string& name) : name_(name) { // NOLINT - WHERE_AM_I(); - } - - LayerNormPlugin(const std::string& name, - const void* data, size_t length) : name_(name) { - WHERE_AM_I(); - } - - LayerNormPlugin() = delete; - - ~LayerNormPlugin() { - WHERE_AM_I(); - } - - size_t getSerializationSize() const noexcept override { - WHERE_AM_I(); - return 0; - } - - void serialize(void *buffer) const noexcept override { - WHERE_AM_I(); - } - - IPluginV2DynamicExt* clone() const noexcept override { - WHERE_AM_I(); - return new LayerNormPlugin(name_); - } - - int getNbOutputs() const noexcept override { - WHERE_AM_I(); - return 1; - } - - DimsExprs getOutputDimensions(int32_t outputIndex, const DimsExprs* inputs, - int32_t nbInputs, - IExprBuilder& exprBuilder) noexcept override { - WHERE_AM_I(); - return inputs[0]; - } - - bool supportsFormatCombination(int32_t pos, const PluginTensorDesc* inOut, - int32_t nbInputs, - int32_t nbOutputs) noexcept override { - WHERE_AM_I(); - if (inOut[pos].format != TensorFormat::kLINEAR) { - return false; - } - - bool res = false; - switch (pos) { - case 0: - res = (inOut[pos].type == DataType::kFLOAT - || inOut[pos].type == DataType::kHALF); break; - case 1: - case 2: - case 3: - res = inOut[pos].type == inOut[0].type; break; - default: // should NOT be here - res = false; break; - } - - return res; - } - - DataType getOutputDataType(int outputIndex, - const DataType* inputTypes, - int nbInputs) const noexcept override { - WHERE_AM_I(); - return inputTypes[0]; - } - - void configurePlugin(const DynamicPluginTensorDesc* in, int32_t nbInputs, - const DynamicPluginTensorDesc* out, - int32_t nbOutputs) noexcept override { - WHERE_AM_I(); - } - - size_t getWorkspaceSize(const PluginTensorDesc* inputs, int32_t nbInputs, - const PluginTensorDesc* outputs, - int32_t nbOutputs) const noexcept override { - WHERE_AM_I(); - return 0; - } - - void setPluginNamespace(const char* szNamespace) noexcept override { - WHERE_AM_I(); - namespace_ = szNamespace; - } - 
const char* getPluginNamespace() const noexcept override { - WHERE_AM_I(); - return namespace_.c_str(); - } - const char* getPluginType() const noexcept override { - WHERE_AM_I(); - return PLUGIN_NAME; - } - const char* getPluginVersion() const noexcept override { - WHERE_AM_I(); - return PLUGIN_VERSION; - } - int initialize() noexcept override { - WHERE_AM_I(); - return 0; - } - void terminate() noexcept override { - WHERE_AM_I(); - return; - } - - void destroy() noexcept override { - WHERE_AM_I(); - } - - int32_t enqueue(const PluginTensorDesc* inputDesc, - const PluginTensorDesc* outputDesc, - const void* const* inputs, - void* const* outputs, void* workspace, - cudaStream_t stream) noexcept override; -}; // class LayerNormPlugin - -class LayerNormPluginCreator : public IPluginCreator { - private: - static PluginFieldCollection fc_; - static std::vector attr_; - std::string namespace_; - - public: - LayerNormPluginCreator() { - fc_.nbFields = attr_.size(); - fc_.fields = attr_.data(); - } - - ~LayerNormPluginCreator() {} - - IPluginV2* createPlugin(const char* name, - const PluginFieldCollection* fc) noexcept override { - WHERE_AM_I(); - return new LayerNormPlugin(name); - } - - IPluginV2* deserializePlugin(const char* name, const void* serialData, - size_t serialLength) noexcept override { - return new LayerNormPlugin(name, serialData, serialLength); - } - - void setPluginNamespace(const char* szNamespace) noexcept override { - namespace_ = szNamespace; - } - - const char* getPluginNamespace() const noexcept override { - return namespace_.c_str(); - } - - const char* getPluginName() const noexcept override { - return PLUGIN_NAME; - } - - const char* getPluginVersion() const noexcept override { - return PLUGIN_VERSION; - } - - const PluginFieldCollection* getFieldNames() noexcept override { - return &fc_; - } -}; // class LayerNormPluginCreator - -} // namespace nvinfer1 -#endif // RUNTIME_GPU_TENSORRT_LAYERNORMPLUGIN_LAYERNORMPLUGIN_H_ +// Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#ifndef RUNTIME_GPU_TENSORRT_LAYERNORMPLUGIN_LAYERNORMPLUGIN_H_ +#define RUNTIME_GPU_TENSORRT_LAYERNORMPLUGIN_LAYERNORMPLUGIN_H_ + +#include +#include +#include +#include +#include // NOLINT +#include // NOLINT + +#define CEIL_DIVIDE(X, Y) (((X)+(Y)-1)/(Y)) +#define CEIL_TO(X, Y) (((X)+(Y)-1)/(Y)*(Y)) + +template +__device__ T epsilon(); + +template <> +__device__ float epsilon() { + return (float)6.0e-12; // NOLINT +} + +template <> +__device__ half epsilon() { + return (half)6.0e-6; +} + +// +------- Debug wrapper ----------------------------------- +#if DEBUG +#define WHERE_AM_I() do {printf("[%s]:this=->%p\n", __func__, this);} while (0); +#else +#define WHERE_AM_I() +#endif // DEBUG + +// +------- Plguin ------------------------------------------- +namespace { // NOLINT +static const char* PLUGIN_NAME{"LayerNorm"}; +static const char* PLUGIN_VERSION{"1"}; +} // namespace + +namespace nvinfer1 { + +// +------- Plugin body --------------------------------------- +class LayerNormPlugin: public IPluginV2DynamicExt { + private: + std::string name_; + std::string namespace_; + + public: + LayerNormPlugin(const std::string& name) : name_(name) { // NOLINT + WHERE_AM_I(); + } + + LayerNormPlugin(const std::string& name, + const void* data, size_t length) : name_(name) { + WHERE_AM_I(); + } + + LayerNormPlugin() = delete; + + ~LayerNormPlugin() { + WHERE_AM_I(); + } + + size_t getSerializationSize() const noexcept override { + WHERE_AM_I(); + return 0; + } + + void serialize(void *buffer) const noexcept override { + WHERE_AM_I(); + } + + IPluginV2DynamicExt* clone() const noexcept override { + WHERE_AM_I(); + return new LayerNormPlugin(name_); + } + + int getNbOutputs() const noexcept override { + WHERE_AM_I(); + return 1; + } + + DimsExprs getOutputDimensions(int32_t outputIndex, const DimsExprs* inputs, + int32_t nbInputs, + IExprBuilder& exprBuilder) noexcept override { + WHERE_AM_I(); + return inputs[0]; + } + + bool supportsFormatCombination(int32_t pos, const PluginTensorDesc* inOut, + int32_t nbInputs, + int32_t nbOutputs) noexcept override { + WHERE_AM_I(); + if (inOut[pos].format != TensorFormat::kLINEAR) { + return false; + } + + bool res = false; + switch (pos) { + case 0: + res = (inOut[pos].type == DataType::kFLOAT + || inOut[pos].type == DataType::kHALF); break; + case 1: + case 2: + case 3: + res = inOut[pos].type == inOut[0].type; break; + default: // should NOT be here + res = false; break; + } + + return res; + } + + DataType getOutputDataType(int outputIndex, + const DataType* inputTypes, + int nbInputs) const noexcept override { + WHERE_AM_I(); + return inputTypes[0]; + } + + void configurePlugin(const DynamicPluginTensorDesc* in, int32_t nbInputs, + const DynamicPluginTensorDesc* out, + int32_t nbOutputs) noexcept override { + WHERE_AM_I(); + } + + size_t getWorkspaceSize(const PluginTensorDesc* inputs, int32_t nbInputs, + const PluginTensorDesc* outputs, + int32_t nbOutputs) const noexcept override { + WHERE_AM_I(); + return 0; + } + + void setPluginNamespace(const char* szNamespace) noexcept override { + WHERE_AM_I(); + namespace_ = szNamespace; + } + const char* getPluginNamespace() const noexcept override { + WHERE_AM_I(); + return namespace_.c_str(); + } + const char* getPluginType() const noexcept override { + WHERE_AM_I(); + return PLUGIN_NAME; + } + const char* getPluginVersion() const noexcept override { + WHERE_AM_I(); + return PLUGIN_VERSION; + } + int initialize() noexcept override { + WHERE_AM_I(); + return 0; + } + void terminate() noexcept override { + 
WHERE_AM_I(); + return; + } + + void destroy() noexcept override { + WHERE_AM_I(); + } + + int32_t enqueue(const PluginTensorDesc* inputDesc, + const PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, void* workspace, + cudaStream_t stream) noexcept override; +}; // class LayerNormPlugin + +class LayerNormPluginCreator : public IPluginCreator { + private: + static PluginFieldCollection fc_; + static std::vector attr_; + std::string namespace_; + + public: + LayerNormPluginCreator() { + fc_.nbFields = attr_.size(); + fc_.fields = attr_.data(); + } + + ~LayerNormPluginCreator() {} + + IPluginV2* createPlugin(const char* name, + const PluginFieldCollection* fc) noexcept override { + WHERE_AM_I(); + return new LayerNormPlugin(name); + } + + IPluginV2* deserializePlugin(const char* name, const void* serialData, + size_t serialLength) noexcept override { + return new LayerNormPlugin(name, serialData, serialLength); + } + + void setPluginNamespace(const char* szNamespace) noexcept override { + namespace_ = szNamespace; + } + + const char* getPluginNamespace() const noexcept override { + return namespace_.c_str(); + } + + const char* getPluginName() const noexcept override { + return PLUGIN_NAME; + } + + const char* getPluginVersion() const noexcept override { + return PLUGIN_VERSION; + } + + const PluginFieldCollection* getFieldNames() noexcept override { + return &fc_; + } +}; // class LayerNormPluginCreator + +} // namespace nvinfer1 +#endif // RUNTIME_GPU_TENSORRT_LAYERNORMPLUGIN_LAYERNORMPLUGIN_H_ diff --git a/models/audio/speech_recognition/conformer/igie/wenet/runtime/horizonbpu/patch/openfst/src/CMakeLists.txt b/models/audio/speech_recognition/conformer/igie/wenet/runtime/horizonbpu/patch/openfst/src/CMakeLists.txt index 04fb1c15b6c1098b61f57f3c276d1a7595b20319..04051ef5ae46c04a40c1ffccc98c37fa594ad13e 100644 --- a/models/audio/speech_recognition/conformer/igie/wenet/runtime/horizonbpu/patch/openfst/src/CMakeLists.txt +++ b/models/audio/speech_recognition/conformer/igie/wenet/runtime/horizonbpu/patch/openfst/src/CMakeLists.txt @@ -1,23 +1,23 @@ - -#-DHAVE_CONFIG_H -I./../include -fno-exceptions -funsigned-char -std=c++11 -MT symbol-table.lo -MD -MP -MF .deps/symbol-table.Tpo -c symbol-table.cc -fno-common -DPIC -o .libs/symbol-table.o - -include_directories(./include/) -install(DIRECTORY include/ DESTINATION include/ - FILES_MATCHING PATTERN "*.h") - -add_subdirectory(lib) - -if(HAVE_SCRIPT) - add_subdirectory(script) -endif(HAVE_SCRIPT) - -if(HAVE_BIN) - add_subdirectory(bin) -endif(HAVE_BIN) - -add_subdirectory(extensions) - -if(BUILD_TESTING) - enable_testing() - add_subdirectory(test) -endif(BUILD_TESTING) + +#-DHAVE_CONFIG_H -I./../include -fno-exceptions -funsigned-char -std=c++11 -MT symbol-table.lo -MD -MP -MF .deps/symbol-table.Tpo -c symbol-table.cc -fno-common -DPIC -o .libs/symbol-table.o + +include_directories(./include/) +install(DIRECTORY include/ DESTINATION include/ + FILES_MATCHING PATTERN "*.h") + +add_subdirectory(lib) + +if(HAVE_SCRIPT) + add_subdirectory(script) +endif(HAVE_SCRIPT) + +if(HAVE_BIN) + add_subdirectory(bin) +endif(HAVE_BIN) + +add_subdirectory(extensions) + +if(BUILD_TESTING) + enable_testing() + add_subdirectory(test) +endif(BUILD_TESTING) diff --git a/models/audio/speech_recognition/conformer/igie/wenet/runtime/horizonbpu/utils/thread_pool.h b/models/audio/speech_recognition/conformer/igie/wenet/runtime/horizonbpu/utils/thread_pool.h index 03cc45b8b2e4bc273706353f248ef088db0eaca8..a78162995d90bf079ad091cf14cb9f2cd4476d05 
100644 --- a/models/audio/speech_recognition/conformer/igie/wenet/runtime/horizonbpu/utils/thread_pool.h +++ b/models/audio/speech_recognition/conformer/igie/wenet/runtime/horizonbpu/utils/thread_pool.h @@ -1,113 +1,113 @@ -// Copyright (c) 2012 Jakob Progsch, Václav Zeman - -// This software is provided 'as-is', without any express or implied -// warranty. In no event will the authors be held liable for any damages -// arising from the use of this software. - -// Permission is granted to anyone to use this software for any purpose, -// including commercial applications, and to alter it and redistribute it -// freely, subject to the following restrictions: - -// 1. The origin of this software must not be misrepresented; you must not -// claim that you wrote the original software. If you use this software -// in a product, an acknowledgment in the product documentation would be -// appreciated but is not required. - -// 2. Altered source versions must be plainly marked as such, and must not be -// misrepresented as being the original software. - -// 3. This notice may not be removed or altered from any source -// distribution. - -#ifndef UTILS_THREAD_POOL_H_ -#define UTILS_THREAD_POOL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -class ThreadPool { - public: - explicit ThreadPool(size_t); - template - auto enqueue(F&& f, Args&&... args) - -> std::future::type>; - ~ThreadPool(); - - private: - // need to keep track of threads so we can join them - std::vector workers; - // the task queue - std::queue > tasks; - - // synchronization - std::mutex queue_mutex; - std::condition_variable condition; - bool stop; -}; - -// the constructor just launches some amount of workers -inline ThreadPool::ThreadPool(size_t threads) : stop(false) { - for (size_t i = 0; i < threads; ++i) - workers.emplace_back([this] { - for (;;) { - std::function task; - - { - std::unique_lock lock(this->queue_mutex); - this->condition.wait( - lock, [this] { return this->stop || !this->tasks.empty(); }); - if (this->stop && this->tasks.empty()) return; - task = std::move(this->tasks.front()); - this->tasks.pop(); - } - - task(); - } - }); -} - -// add new work item to the pool -template -auto ThreadPool::enqueue(F&& f, Args&&... args) - -> std::future::type> { - using return_type = typename std::result_of::type; - - auto task = std::make_shared >( - std::bind(std::forward(f), std::forward(args)...)); - - std::future res = task->get_future(); - { - std::unique_lock lock(queue_mutex); - - // don't allow enqueueing after stopping the pool - if (stop) { - throw std::runtime_error("enqueue on stopped ThreadPool"); - } - - tasks.emplace([task]() { (*task)(); }); - } - condition.notify_one(); - return res; -} - -// the destructor joins all threads -inline ThreadPool::~ThreadPool() { - { - std::unique_lock lock(queue_mutex); - stop = true; - } - condition.notify_all(); - for (std::thread& worker : workers) { - worker.join(); - } -} - -#endif // UTILS_THREAD_POOL_H_ +// Copyright (c) 2012 Jakob Progsch, Václav Zeman + +// This software is provided 'as-is', without any express or implied +// warranty. In no event will the authors be held liable for any damages +// arising from the use of this software. + +// Permission is granted to anyone to use this software for any purpose, +// including commercial applications, and to alter it and redistribute it +// freely, subject to the following restrictions: + +// 1. 
The origin of this software must not be misrepresented; you must not +// claim that you wrote the original software. If you use this software +// in a product, an acknowledgment in the product documentation would be +// appreciated but is not required. + +// 2. Altered source versions must be plainly marked as such, and must not be +// misrepresented as being the original software. + +// 3. This notice may not be removed or altered from any source +// distribution. + +#ifndef UTILS_THREAD_POOL_H_ +#define UTILS_THREAD_POOL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +class ThreadPool { + public: + explicit ThreadPool(size_t); + template + auto enqueue(F&& f, Args&&... args) + -> std::future::type>; + ~ThreadPool(); + + private: + // need to keep track of threads so we can join them + std::vector workers; + // the task queue + std::queue > tasks; + + // synchronization + std::mutex queue_mutex; + std::condition_variable condition; + bool stop; +}; + +// the constructor just launches some amount of workers +inline ThreadPool::ThreadPool(size_t threads) : stop(false) { + for (size_t i = 0; i < threads; ++i) + workers.emplace_back([this] { + for (;;) { + std::function task; + + { + std::unique_lock lock(this->queue_mutex); + this->condition.wait( + lock, [this] { return this->stop || !this->tasks.empty(); }); + if (this->stop && this->tasks.empty()) return; + task = std::move(this->tasks.front()); + this->tasks.pop(); + } + + task(); + } + }); +} + +// add new work item to the pool +template +auto ThreadPool::enqueue(F&& f, Args&&... args) + -> std::future::type> { + using return_type = typename std::result_of::type; + + auto task = std::make_shared >( + std::bind(std::forward(f), std::forward(args)...)); + + std::future res = task->get_future(); + { + std::unique_lock lock(queue_mutex); + + // don't allow enqueueing after stopping the pool + if (stop) { + throw std::runtime_error("enqueue on stopped ThreadPool"); + } + + tasks.emplace([task]() { (*task)(); }); + } + condition.notify_one(); + return res; +} + +// the destructor joins all threads +inline ThreadPool::~ThreadPool() { + { + std::unique_lock lock(queue_mutex); + stop = true; + } + condition.notify_all(); + for (std::thread& worker : workers) { + worker.join(); + } +} + +#endif // UTILS_THREAD_POOL_H_ diff --git a/models/audio/speech_recognition/conformer/igie/wenet/runtime/ios/patch/openfst/src/CMakeLists.txt b/models/audio/speech_recognition/conformer/igie/wenet/runtime/ios/patch/openfst/src/CMakeLists.txt index 04fb1c15b6c1098b61f57f3c276d1a7595b20319..04051ef5ae46c04a40c1ffccc98c37fa594ad13e 100644 --- a/models/audio/speech_recognition/conformer/igie/wenet/runtime/ios/patch/openfst/src/CMakeLists.txt +++ b/models/audio/speech_recognition/conformer/igie/wenet/runtime/ios/patch/openfst/src/CMakeLists.txt @@ -1,23 +1,23 @@ - -#-DHAVE_CONFIG_H -I./../include -fno-exceptions -funsigned-char -std=c++11 -MT symbol-table.lo -MD -MP -MF .deps/symbol-table.Tpo -c symbol-table.cc -fno-common -DPIC -o .libs/symbol-table.o - -include_directories(./include/) -install(DIRECTORY include/ DESTINATION include/ - FILES_MATCHING PATTERN "*.h") - -add_subdirectory(lib) - -if(HAVE_SCRIPT) - add_subdirectory(script) -endif(HAVE_SCRIPT) - -if(HAVE_BIN) - add_subdirectory(bin) -endif(HAVE_BIN) - -add_subdirectory(extensions) - -if(BUILD_TESTING) - enable_testing() - add_subdirectory(test) -endif(BUILD_TESTING) + +#-DHAVE_CONFIG_H -I./../include -fno-exceptions -funsigned-char -std=c++11 -MT 
symbol-table.lo -MD -MP -MF .deps/symbol-table.Tpo -c symbol-table.cc -fno-common -DPIC -o .libs/symbol-table.o + +include_directories(./include/) +install(DIRECTORY include/ DESTINATION include/ + FILES_MATCHING PATTERN "*.h") + +add_subdirectory(lib) + +if(HAVE_SCRIPT) + add_subdirectory(script) +endif(HAVE_SCRIPT) + +if(HAVE_BIN) + add_subdirectory(bin) +endif(HAVE_BIN) + +add_subdirectory(extensions) + +if(BUILD_TESTING) + enable_testing() + add_subdirectory(test) +endif(BUILD_TESTING) diff --git a/models/audio/speech_recognition/conformer/igie/wenet/runtime/ios/utils/thread_pool.h b/models/audio/speech_recognition/conformer/igie/wenet/runtime/ios/utils/thread_pool.h index 03cc45b8b2e4bc273706353f248ef088db0eaca8..a78162995d90bf079ad091cf14cb9f2cd4476d05 100644 --- a/models/audio/speech_recognition/conformer/igie/wenet/runtime/ios/utils/thread_pool.h +++ b/models/audio/speech_recognition/conformer/igie/wenet/runtime/ios/utils/thread_pool.h @@ -1,113 +1,113 @@ -// Copyright (c) 2012 Jakob Progsch, Václav Zeman - -// This software is provided 'as-is', without any express or implied -// warranty. In no event will the authors be held liable for any damages -// arising from the use of this software. - -// Permission is granted to anyone to use this software for any purpose, -// including commercial applications, and to alter it and redistribute it -// freely, subject to the following restrictions: - -// 1. The origin of this software must not be misrepresented; you must not -// claim that you wrote the original software. If you use this software -// in a product, an acknowledgment in the product documentation would be -// appreciated but is not required. - -// 2. Altered source versions must be plainly marked as such, and must not be -// misrepresented as being the original software. - -// 3. This notice may not be removed or altered from any source -// distribution. - -#ifndef UTILS_THREAD_POOL_H_ -#define UTILS_THREAD_POOL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -class ThreadPool { - public: - explicit ThreadPool(size_t); - template - auto enqueue(F&& f, Args&&... args) - -> std::future::type>; - ~ThreadPool(); - - private: - // need to keep track of threads so we can join them - std::vector workers; - // the task queue - std::queue > tasks; - - // synchronization - std::mutex queue_mutex; - std::condition_variable condition; - bool stop; -}; - -// the constructor just launches some amount of workers -inline ThreadPool::ThreadPool(size_t threads) : stop(false) { - for (size_t i = 0; i < threads; ++i) - workers.emplace_back([this] { - for (;;) { - std::function task; - - { - std::unique_lock lock(this->queue_mutex); - this->condition.wait( - lock, [this] { return this->stop || !this->tasks.empty(); }); - if (this->stop && this->tasks.empty()) return; - task = std::move(this->tasks.front()); - this->tasks.pop(); - } - - task(); - } - }); -} - -// add new work item to the pool -template -auto ThreadPool::enqueue(F&& f, Args&&... 
args) - -> std::future::type> { - using return_type = typename std::result_of::type; - - auto task = std::make_shared >( - std::bind(std::forward(f), std::forward(args)...)); - - std::future res = task->get_future(); - { - std::unique_lock lock(queue_mutex); - - // don't allow enqueueing after stopping the pool - if (stop) { - throw std::runtime_error("enqueue on stopped ThreadPool"); - } - - tasks.emplace([task]() { (*task)(); }); - } - condition.notify_one(); - return res; -} - -// the destructor joins all threads -inline ThreadPool::~ThreadPool() { - { - std::unique_lock lock(queue_mutex); - stop = true; - } - condition.notify_all(); - for (std::thread& worker : workers) { - worker.join(); - } -} - -#endif // UTILS_THREAD_POOL_H_ +// Copyright (c) 2012 Jakob Progsch, Václav Zeman + +// This software is provided 'as-is', without any express or implied +// warranty. In no event will the authors be held liable for any damages +// arising from the use of this software. + +// Permission is granted to anyone to use this software for any purpose, +// including commercial applications, and to alter it and redistribute it +// freely, subject to the following restrictions: + +// 1. The origin of this software must not be misrepresented; you must not +// claim that you wrote the original software. If you use this software +// in a product, an acknowledgment in the product documentation would be +// appreciated but is not required. + +// 2. Altered source versions must be plainly marked as such, and must not be +// misrepresented as being the original software. + +// 3. This notice may not be removed or altered from any source +// distribution. + +#ifndef UTILS_THREAD_POOL_H_ +#define UTILS_THREAD_POOL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +class ThreadPool { + public: + explicit ThreadPool(size_t); + template + auto enqueue(F&& f, Args&&... args) + -> std::future::type>; + ~ThreadPool(); + + private: + // need to keep track of threads so we can join them + std::vector workers; + // the task queue + std::queue > tasks; + + // synchronization + std::mutex queue_mutex; + std::condition_variable condition; + bool stop; +}; + +// the constructor just launches some amount of workers +inline ThreadPool::ThreadPool(size_t threads) : stop(false) { + for (size_t i = 0; i < threads; ++i) + workers.emplace_back([this] { + for (;;) { + std::function task; + + { + std::unique_lock lock(this->queue_mutex); + this->condition.wait( + lock, [this] { return this->stop || !this->tasks.empty(); }); + if (this->stop && this->tasks.empty()) return; + task = std::move(this->tasks.front()); + this->tasks.pop(); + } + + task(); + } + }); +} + +// add new work item to the pool +template +auto ThreadPool::enqueue(F&& f, Args&&... 
args) + -> std::future::type> { + using return_type = typename std::result_of::type; + + auto task = std::make_shared >( + std::bind(std::forward(f), std::forward(args)...)); + + std::future res = task->get_future(); + { + std::unique_lock lock(queue_mutex); + + // don't allow enqueueing after stopping the pool + if (stop) { + throw std::runtime_error("enqueue on stopped ThreadPool"); + } + + tasks.emplace([task]() { (*task)(); }); + } + condition.notify_one(); + return res; +} + +// the destructor joins all threads +inline ThreadPool::~ThreadPool() { + { + std::unique_lock lock(queue_mutex); + stop = true; + } + condition.notify_all(); + for (std::thread& worker : workers) { + worker.join(); + } +} + +#endif // UTILS_THREAD_POOL_H_ diff --git a/models/audio/speech_recognition/conformer/igie/wenet/runtime/kunlun/patch/openfst/src/CMakeLists.txt b/models/audio/speech_recognition/conformer/igie/wenet/runtime/kunlun/patch/openfst/src/CMakeLists.txt index 04fb1c15b6c1098b61f57f3c276d1a7595b20319..04051ef5ae46c04a40c1ffccc98c37fa594ad13e 100644 --- a/models/audio/speech_recognition/conformer/igie/wenet/runtime/kunlun/patch/openfst/src/CMakeLists.txt +++ b/models/audio/speech_recognition/conformer/igie/wenet/runtime/kunlun/patch/openfst/src/CMakeLists.txt @@ -1,23 +1,23 @@ - -#-DHAVE_CONFIG_H -I./../include -fno-exceptions -funsigned-char -std=c++11 -MT symbol-table.lo -MD -MP -MF .deps/symbol-table.Tpo -c symbol-table.cc -fno-common -DPIC -o .libs/symbol-table.o - -include_directories(./include/) -install(DIRECTORY include/ DESTINATION include/ - FILES_MATCHING PATTERN "*.h") - -add_subdirectory(lib) - -if(HAVE_SCRIPT) - add_subdirectory(script) -endif(HAVE_SCRIPT) - -if(HAVE_BIN) - add_subdirectory(bin) -endif(HAVE_BIN) - -add_subdirectory(extensions) - -if(BUILD_TESTING) - enable_testing() - add_subdirectory(test) -endif(BUILD_TESTING) + +#-DHAVE_CONFIG_H -I./../include -fno-exceptions -funsigned-char -std=c++11 -MT symbol-table.lo -MD -MP -MF .deps/symbol-table.Tpo -c symbol-table.cc -fno-common -DPIC -o .libs/symbol-table.o + +include_directories(./include/) +install(DIRECTORY include/ DESTINATION include/ + FILES_MATCHING PATTERN "*.h") + +add_subdirectory(lib) + +if(HAVE_SCRIPT) + add_subdirectory(script) +endif(HAVE_SCRIPT) + +if(HAVE_BIN) + add_subdirectory(bin) +endif(HAVE_BIN) + +add_subdirectory(extensions) + +if(BUILD_TESTING) + enable_testing() + add_subdirectory(test) +endif(BUILD_TESTING) diff --git a/models/audio/speech_recognition/conformer/igie/wenet/runtime/kunlun/utils/thread_pool.h b/models/audio/speech_recognition/conformer/igie/wenet/runtime/kunlun/utils/thread_pool.h index 03cc45b8b2e4bc273706353f248ef088db0eaca8..a78162995d90bf079ad091cf14cb9f2cd4476d05 100644 --- a/models/audio/speech_recognition/conformer/igie/wenet/runtime/kunlun/utils/thread_pool.h +++ b/models/audio/speech_recognition/conformer/igie/wenet/runtime/kunlun/utils/thread_pool.h @@ -1,113 +1,113 @@ -// Copyright (c) 2012 Jakob Progsch, Václav Zeman - -// This software is provided 'as-is', without any express or implied -// warranty. In no event will the authors be held liable for any damages -// arising from the use of this software. - -// Permission is granted to anyone to use this software for any purpose, -// including commercial applications, and to alter it and redistribute it -// freely, subject to the following restrictions: - -// 1. The origin of this software must not be misrepresented; you must not -// claim that you wrote the original software. 
If you use this software -// in a product, an acknowledgment in the product documentation would be -// appreciated but is not required. - -// 2. Altered source versions must be plainly marked as such, and must not be -// misrepresented as being the original software. - -// 3. This notice may not be removed or altered from any source -// distribution. - -#ifndef UTILS_THREAD_POOL_H_ -#define UTILS_THREAD_POOL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -class ThreadPool { - public: - explicit ThreadPool(size_t); - template - auto enqueue(F&& f, Args&&... args) - -> std::future::type>; - ~ThreadPool(); - - private: - // need to keep track of threads so we can join them - std::vector workers; - // the task queue - std::queue > tasks; - - // synchronization - std::mutex queue_mutex; - std::condition_variable condition; - bool stop; -}; - -// the constructor just launches some amount of workers -inline ThreadPool::ThreadPool(size_t threads) : stop(false) { - for (size_t i = 0; i < threads; ++i) - workers.emplace_back([this] { - for (;;) { - std::function task; - - { - std::unique_lock lock(this->queue_mutex); - this->condition.wait( - lock, [this] { return this->stop || !this->tasks.empty(); }); - if (this->stop && this->tasks.empty()) return; - task = std::move(this->tasks.front()); - this->tasks.pop(); - } - - task(); - } - }); -} - -// add new work item to the pool -template -auto ThreadPool::enqueue(F&& f, Args&&... args) - -> std::future::type> { - using return_type = typename std::result_of::type; - - auto task = std::make_shared >( - std::bind(std::forward(f), std::forward(args)...)); - - std::future res = task->get_future(); - { - std::unique_lock lock(queue_mutex); - - // don't allow enqueueing after stopping the pool - if (stop) { - throw std::runtime_error("enqueue on stopped ThreadPool"); - } - - tasks.emplace([task]() { (*task)(); }); - } - condition.notify_one(); - return res; -} - -// the destructor joins all threads -inline ThreadPool::~ThreadPool() { - { - std::unique_lock lock(queue_mutex); - stop = true; - } - condition.notify_all(); - for (std::thread& worker : workers) { - worker.join(); - } -} - -#endif // UTILS_THREAD_POOL_H_ +// Copyright (c) 2012 Jakob Progsch, Václav Zeman + +// This software is provided 'as-is', without any express or implied +// warranty. In no event will the authors be held liable for any damages +// arising from the use of this software. + +// Permission is granted to anyone to use this software for any purpose, +// including commercial applications, and to alter it and redistribute it +// freely, subject to the following restrictions: + +// 1. The origin of this software must not be misrepresented; you must not +// claim that you wrote the original software. If you use this software +// in a product, an acknowledgment in the product documentation would be +// appreciated but is not required. + +// 2. Altered source versions must be plainly marked as such, and must not be +// misrepresented as being the original software. + +// 3. This notice may not be removed or altered from any source +// distribution. + +#ifndef UTILS_THREAD_POOL_H_ +#define UTILS_THREAD_POOL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +class ThreadPool { + public: + explicit ThreadPool(size_t); + template + auto enqueue(F&& f, Args&&... 
args) + -> std::future::type>; + ~ThreadPool(); + + private: + // need to keep track of threads so we can join them + std::vector workers; + // the task queue + std::queue > tasks; + + // synchronization + std::mutex queue_mutex; + std::condition_variable condition; + bool stop; +}; + +// the constructor just launches some amount of workers +inline ThreadPool::ThreadPool(size_t threads) : stop(false) { + for (size_t i = 0; i < threads; ++i) + workers.emplace_back([this] { + for (;;) { + std::function task; + + { + std::unique_lock lock(this->queue_mutex); + this->condition.wait( + lock, [this] { return this->stop || !this->tasks.empty(); }); + if (this->stop && this->tasks.empty()) return; + task = std::move(this->tasks.front()); + this->tasks.pop(); + } + + task(); + } + }); +} + +// add new work item to the pool +template +auto ThreadPool::enqueue(F&& f, Args&&... args) + -> std::future::type> { + using return_type = typename std::result_of::type; + + auto task = std::make_shared >( + std::bind(std::forward(f), std::forward(args)...)); + + std::future res = task->get_future(); + { + std::unique_lock lock(queue_mutex); + + // don't allow enqueueing after stopping the pool + if (stop) { + throw std::runtime_error("enqueue on stopped ThreadPool"); + } + + tasks.emplace([task]() { (*task)(); }); + } + condition.notify_one(); + return res; +} + +// the destructor joins all threads +inline ThreadPool::~ThreadPool() { + { + std::unique_lock lock(queue_mutex); + stop = true; + } + condition.notify_all(); + for (std::thread& worker : workers) { + worker.join(); + } +} + +#endif // UTILS_THREAD_POOL_H_ diff --git a/models/audio/speech_recognition/conformer/igie/wenet/runtime/libtorch/patch/openfst/src/CMakeLists.txt b/models/audio/speech_recognition/conformer/igie/wenet/runtime/libtorch/patch/openfst/src/CMakeLists.txt index 04fb1c15b6c1098b61f57f3c276d1a7595b20319..04051ef5ae46c04a40c1ffccc98c37fa594ad13e 100644 --- a/models/audio/speech_recognition/conformer/igie/wenet/runtime/libtorch/patch/openfst/src/CMakeLists.txt +++ b/models/audio/speech_recognition/conformer/igie/wenet/runtime/libtorch/patch/openfst/src/CMakeLists.txt @@ -1,23 +1,23 @@ - -#-DHAVE_CONFIG_H -I./../include -fno-exceptions -funsigned-char -std=c++11 -MT symbol-table.lo -MD -MP -MF .deps/symbol-table.Tpo -c symbol-table.cc -fno-common -DPIC -o .libs/symbol-table.o - -include_directories(./include/) -install(DIRECTORY include/ DESTINATION include/ - FILES_MATCHING PATTERN "*.h") - -add_subdirectory(lib) - -if(HAVE_SCRIPT) - add_subdirectory(script) -endif(HAVE_SCRIPT) - -if(HAVE_BIN) - add_subdirectory(bin) -endif(HAVE_BIN) - -add_subdirectory(extensions) - -if(BUILD_TESTING) - enable_testing() - add_subdirectory(test) -endif(BUILD_TESTING) + +#-DHAVE_CONFIG_H -I./../include -fno-exceptions -funsigned-char -std=c++11 -MT symbol-table.lo -MD -MP -MF .deps/symbol-table.Tpo -c symbol-table.cc -fno-common -DPIC -o .libs/symbol-table.o + +include_directories(./include/) +install(DIRECTORY include/ DESTINATION include/ + FILES_MATCHING PATTERN "*.h") + +add_subdirectory(lib) + +if(HAVE_SCRIPT) + add_subdirectory(script) +endif(HAVE_SCRIPT) + +if(HAVE_BIN) + add_subdirectory(bin) +endif(HAVE_BIN) + +add_subdirectory(extensions) + +if(BUILD_TESTING) + enable_testing() + add_subdirectory(test) +endif(BUILD_TESTING) diff --git a/models/audio/speech_recognition/conformer/igie/wenet/runtime/libtorch/utils/thread_pool.h b/models/audio/speech_recognition/conformer/igie/wenet/runtime/libtorch/utils/thread_pool.h index 
03cc45b8b2e4bc273706353f248ef088db0eaca8..a78162995d90bf079ad091cf14cb9f2cd4476d05 100644 --- a/models/audio/speech_recognition/conformer/igie/wenet/runtime/libtorch/utils/thread_pool.h +++ b/models/audio/speech_recognition/conformer/igie/wenet/runtime/libtorch/utils/thread_pool.h @@ -1,113 +1,113 @@ -// Copyright (c) 2012 Jakob Progsch, Václav Zeman - -// This software is provided 'as-is', without any express or implied -// warranty. In no event will the authors be held liable for any damages -// arising from the use of this software. - -// Permission is granted to anyone to use this software for any purpose, -// including commercial applications, and to alter it and redistribute it -// freely, subject to the following restrictions: - -// 1. The origin of this software must not be misrepresented; you must not -// claim that you wrote the original software. If you use this software -// in a product, an acknowledgment in the product documentation would be -// appreciated but is not required. - -// 2. Altered source versions must be plainly marked as such, and must not be -// misrepresented as being the original software. - -// 3. This notice may not be removed or altered from any source -// distribution. - -#ifndef UTILS_THREAD_POOL_H_ -#define UTILS_THREAD_POOL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -class ThreadPool { - public: - explicit ThreadPool(size_t); - template - auto enqueue(F&& f, Args&&... args) - -> std::future::type>; - ~ThreadPool(); - - private: - // need to keep track of threads so we can join them - std::vector workers; - // the task queue - std::queue > tasks; - - // synchronization - std::mutex queue_mutex; - std::condition_variable condition; - bool stop; -}; - -// the constructor just launches some amount of workers -inline ThreadPool::ThreadPool(size_t threads) : stop(false) { - for (size_t i = 0; i < threads; ++i) - workers.emplace_back([this] { - for (;;) { - std::function task; - - { - std::unique_lock lock(this->queue_mutex); - this->condition.wait( - lock, [this] { return this->stop || !this->tasks.empty(); }); - if (this->stop && this->tasks.empty()) return; - task = std::move(this->tasks.front()); - this->tasks.pop(); - } - - task(); - } - }); -} - -// add new work item to the pool -template -auto ThreadPool::enqueue(F&& f, Args&&... args) - -> std::future::type> { - using return_type = typename std::result_of::type; - - auto task = std::make_shared >( - std::bind(std::forward(f), std::forward(args)...)); - - std::future res = task->get_future(); - { - std::unique_lock lock(queue_mutex); - - // don't allow enqueueing after stopping the pool - if (stop) { - throw std::runtime_error("enqueue on stopped ThreadPool"); - } - - tasks.emplace([task]() { (*task)(); }); - } - condition.notify_one(); - return res; -} - -// the destructor joins all threads -inline ThreadPool::~ThreadPool() { - { - std::unique_lock lock(queue_mutex); - stop = true; - } - condition.notify_all(); - for (std::thread& worker : workers) { - worker.join(); - } -} - -#endif // UTILS_THREAD_POOL_H_ +// Copyright (c) 2012 Jakob Progsch, Václav Zeman + +// This software is provided 'as-is', without any express or implied +// warranty. In no event will the authors be held liable for any damages +// arising from the use of this software. 
+ +// Permission is granted to anyone to use this software for any purpose, +// including commercial applications, and to alter it and redistribute it +// freely, subject to the following restrictions: + +// 1. The origin of this software must not be misrepresented; you must not +// claim that you wrote the original software. If you use this software +// in a product, an acknowledgment in the product documentation would be +// appreciated but is not required. + +// 2. Altered source versions must be plainly marked as such, and must not be +// misrepresented as being the original software. + +// 3. This notice may not be removed or altered from any source +// distribution. + +#ifndef UTILS_THREAD_POOL_H_ +#define UTILS_THREAD_POOL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +class ThreadPool { + public: + explicit ThreadPool(size_t); + template + auto enqueue(F&& f, Args&&... args) + -> std::future::type>; + ~ThreadPool(); + + private: + // need to keep track of threads so we can join them + std::vector workers; + // the task queue + std::queue > tasks; + + // synchronization + std::mutex queue_mutex; + std::condition_variable condition; + bool stop; +}; + +// the constructor just launches some amount of workers +inline ThreadPool::ThreadPool(size_t threads) : stop(false) { + for (size_t i = 0; i < threads; ++i) + workers.emplace_back([this] { + for (;;) { + std::function task; + + { + std::unique_lock lock(this->queue_mutex); + this->condition.wait( + lock, [this] { return this->stop || !this->tasks.empty(); }); + if (this->stop && this->tasks.empty()) return; + task = std::move(this->tasks.front()); + this->tasks.pop(); + } + + task(); + } + }); +} + +// add new work item to the pool +template +auto ThreadPool::enqueue(F&& f, Args&&... 
args) + -> std::future::type> { + using return_type = typename std::result_of::type; + + auto task = std::make_shared >( + std::bind(std::forward(f), std::forward(args)...)); + + std::future res = task->get_future(); + { + std::unique_lock lock(queue_mutex); + + // don't allow enqueueing after stopping the pool + if (stop) { + throw std::runtime_error("enqueue on stopped ThreadPool"); + } + + tasks.emplace([task]() { (*task)(); }); + } + condition.notify_one(); + return res; +} + +// the destructor joins all threads +inline ThreadPool::~ThreadPool() { + { + std::unique_lock lock(queue_mutex); + stop = true; + } + condition.notify_all(); + for (std::thread& worker : workers) { + worker.join(); + } +} + +#endif // UTILS_THREAD_POOL_H_ diff --git a/models/audio/speech_recognition/conformer/igie/wenet/runtime/onnxruntime/patch/openfst/src/CMakeLists.txt b/models/audio/speech_recognition/conformer/igie/wenet/runtime/onnxruntime/patch/openfst/src/CMakeLists.txt index 04fb1c15b6c1098b61f57f3c276d1a7595b20319..04051ef5ae46c04a40c1ffccc98c37fa594ad13e 100644 --- a/models/audio/speech_recognition/conformer/igie/wenet/runtime/onnxruntime/patch/openfst/src/CMakeLists.txt +++ b/models/audio/speech_recognition/conformer/igie/wenet/runtime/onnxruntime/patch/openfst/src/CMakeLists.txt @@ -1,23 +1,23 @@ - -#-DHAVE_CONFIG_H -I./../include -fno-exceptions -funsigned-char -std=c++11 -MT symbol-table.lo -MD -MP -MF .deps/symbol-table.Tpo -c symbol-table.cc -fno-common -DPIC -o .libs/symbol-table.o - -include_directories(./include/) -install(DIRECTORY include/ DESTINATION include/ - FILES_MATCHING PATTERN "*.h") - -add_subdirectory(lib) - -if(HAVE_SCRIPT) - add_subdirectory(script) -endif(HAVE_SCRIPT) - -if(HAVE_BIN) - add_subdirectory(bin) -endif(HAVE_BIN) - -add_subdirectory(extensions) - -if(BUILD_TESTING) - enable_testing() - add_subdirectory(test) -endif(BUILD_TESTING) + +#-DHAVE_CONFIG_H -I./../include -fno-exceptions -funsigned-char -std=c++11 -MT symbol-table.lo -MD -MP -MF .deps/symbol-table.Tpo -c symbol-table.cc -fno-common -DPIC -o .libs/symbol-table.o + +include_directories(./include/) +install(DIRECTORY include/ DESTINATION include/ + FILES_MATCHING PATTERN "*.h") + +add_subdirectory(lib) + +if(HAVE_SCRIPT) + add_subdirectory(script) +endif(HAVE_SCRIPT) + +if(HAVE_BIN) + add_subdirectory(bin) +endif(HAVE_BIN) + +add_subdirectory(extensions) + +if(BUILD_TESTING) + enable_testing() + add_subdirectory(test) +endif(BUILD_TESTING) diff --git a/models/audio/speech_recognition/conformer/igie/wenet/runtime/onnxruntime/utils/thread_pool.h b/models/audio/speech_recognition/conformer/igie/wenet/runtime/onnxruntime/utils/thread_pool.h index 03cc45b8b2e4bc273706353f248ef088db0eaca8..a78162995d90bf079ad091cf14cb9f2cd4476d05 100644 --- a/models/audio/speech_recognition/conformer/igie/wenet/runtime/onnxruntime/utils/thread_pool.h +++ b/models/audio/speech_recognition/conformer/igie/wenet/runtime/onnxruntime/utils/thread_pool.h @@ -1,113 +1,113 @@ -// Copyright (c) 2012 Jakob Progsch, Václav Zeman - -// This software is provided 'as-is', without any express or implied -// warranty. In no event will the authors be held liable for any damages -// arising from the use of this software. - -// Permission is granted to anyone to use this software for any purpose, -// including commercial applications, and to alter it and redistribute it -// freely, subject to the following restrictions: - -// 1. The origin of this software must not be misrepresented; you must not -// claim that you wrote the original software. 
If you use this software -// in a product, an acknowledgment in the product documentation would be -// appreciated but is not required. - -// 2. Altered source versions must be plainly marked as such, and must not be -// misrepresented as being the original software. - -// 3. This notice may not be removed or altered from any source -// distribution. - -#ifndef UTILS_THREAD_POOL_H_ -#define UTILS_THREAD_POOL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -class ThreadPool { - public: - explicit ThreadPool(size_t); - template - auto enqueue(F&& f, Args&&... args) - -> std::future::type>; - ~ThreadPool(); - - private: - // need to keep track of threads so we can join them - std::vector workers; - // the task queue - std::queue > tasks; - - // synchronization - std::mutex queue_mutex; - std::condition_variable condition; - bool stop; -}; - -// the constructor just launches some amount of workers -inline ThreadPool::ThreadPool(size_t threads) : stop(false) { - for (size_t i = 0; i < threads; ++i) - workers.emplace_back([this] { - for (;;) { - std::function task; - - { - std::unique_lock lock(this->queue_mutex); - this->condition.wait( - lock, [this] { return this->stop || !this->tasks.empty(); }); - if (this->stop && this->tasks.empty()) return; - task = std::move(this->tasks.front()); - this->tasks.pop(); - } - - task(); - } - }); -} - -// add new work item to the pool -template -auto ThreadPool::enqueue(F&& f, Args&&... args) - -> std::future::type> { - using return_type = typename std::result_of::type; - - auto task = std::make_shared >( - std::bind(std::forward(f), std::forward(args)...)); - - std::future res = task->get_future(); - { - std::unique_lock lock(queue_mutex); - - // don't allow enqueueing after stopping the pool - if (stop) { - throw std::runtime_error("enqueue on stopped ThreadPool"); - } - - tasks.emplace([task]() { (*task)(); }); - } - condition.notify_one(); - return res; -} - -// the destructor joins all threads -inline ThreadPool::~ThreadPool() { - { - std::unique_lock lock(queue_mutex); - stop = true; - } - condition.notify_all(); - for (std::thread& worker : workers) { - worker.join(); - } -} - -#endif // UTILS_THREAD_POOL_H_ +// Copyright (c) 2012 Jakob Progsch, Václav Zeman + +// This software is provided 'as-is', without any express or implied +// warranty. In no event will the authors be held liable for any damages +// arising from the use of this software. + +// Permission is granted to anyone to use this software for any purpose, +// including commercial applications, and to alter it and redistribute it +// freely, subject to the following restrictions: + +// 1. The origin of this software must not be misrepresented; you must not +// claim that you wrote the original software. If you use this software +// in a product, an acknowledgment in the product documentation would be +// appreciated but is not required. + +// 2. Altered source versions must be plainly marked as such, and must not be +// misrepresented as being the original software. + +// 3. This notice may not be removed or altered from any source +// distribution. + +#ifndef UTILS_THREAD_POOL_H_ +#define UTILS_THREAD_POOL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +class ThreadPool { + public: + explicit ThreadPool(size_t); + template + auto enqueue(F&& f, Args&&... 
args) + -> std::future::type>; + ~ThreadPool(); + + private: + // need to keep track of threads so we can join them + std::vector workers; + // the task queue + std::queue > tasks; + + // synchronization + std::mutex queue_mutex; + std::condition_variable condition; + bool stop; +}; + +// the constructor just launches some amount of workers +inline ThreadPool::ThreadPool(size_t threads) : stop(false) { + for (size_t i = 0; i < threads; ++i) + workers.emplace_back([this] { + for (;;) { + std::function task; + + { + std::unique_lock lock(this->queue_mutex); + this->condition.wait( + lock, [this] { return this->stop || !this->tasks.empty(); }); + if (this->stop && this->tasks.empty()) return; + task = std::move(this->tasks.front()); + this->tasks.pop(); + } + + task(); + } + }); +} + +// add new work item to the pool +template +auto ThreadPool::enqueue(F&& f, Args&&... args) + -> std::future::type> { + using return_type = typename std::result_of::type; + + auto task = std::make_shared >( + std::bind(std::forward(f), std::forward(args)...)); + + std::future res = task->get_future(); + { + std::unique_lock lock(queue_mutex); + + // don't allow enqueueing after stopping the pool + if (stop) { + throw std::runtime_error("enqueue on stopped ThreadPool"); + } + + tasks.emplace([task]() { (*task)(); }); + } + condition.notify_one(); + return res; +} + +// the destructor joins all threads +inline ThreadPool::~ThreadPool() { + { + std::unique_lock lock(queue_mutex); + stop = true; + } + condition.notify_all(); + for (std::thread& worker : workers) { + worker.join(); + } +} + +#endif // UTILS_THREAD_POOL_H_ diff --git a/models/audio/speech_recognition/conformer/igie/wenet/runtime/raspberrypi/patch/openfst/src/CMakeLists.txt b/models/audio/speech_recognition/conformer/igie/wenet/runtime/raspberrypi/patch/openfst/src/CMakeLists.txt index 04fb1c15b6c1098b61f57f3c276d1a7595b20319..04051ef5ae46c04a40c1ffccc98c37fa594ad13e 100644 --- a/models/audio/speech_recognition/conformer/igie/wenet/runtime/raspberrypi/patch/openfst/src/CMakeLists.txt +++ b/models/audio/speech_recognition/conformer/igie/wenet/runtime/raspberrypi/patch/openfst/src/CMakeLists.txt @@ -1,23 +1,23 @@ - -#-DHAVE_CONFIG_H -I./../include -fno-exceptions -funsigned-char -std=c++11 -MT symbol-table.lo -MD -MP -MF .deps/symbol-table.Tpo -c symbol-table.cc -fno-common -DPIC -o .libs/symbol-table.o - -include_directories(./include/) -install(DIRECTORY include/ DESTINATION include/ - FILES_MATCHING PATTERN "*.h") - -add_subdirectory(lib) - -if(HAVE_SCRIPT) - add_subdirectory(script) -endif(HAVE_SCRIPT) - -if(HAVE_BIN) - add_subdirectory(bin) -endif(HAVE_BIN) - -add_subdirectory(extensions) - -if(BUILD_TESTING) - enable_testing() - add_subdirectory(test) -endif(BUILD_TESTING) + +#-DHAVE_CONFIG_H -I./../include -fno-exceptions -funsigned-char -std=c++11 -MT symbol-table.lo -MD -MP -MF .deps/symbol-table.Tpo -c symbol-table.cc -fno-common -DPIC -o .libs/symbol-table.o + +include_directories(./include/) +install(DIRECTORY include/ DESTINATION include/ + FILES_MATCHING PATTERN "*.h") + +add_subdirectory(lib) + +if(HAVE_SCRIPT) + add_subdirectory(script) +endif(HAVE_SCRIPT) + +if(HAVE_BIN) + add_subdirectory(bin) +endif(HAVE_BIN) + +add_subdirectory(extensions) + +if(BUILD_TESTING) + enable_testing() + add_subdirectory(test) +endif(BUILD_TESTING) diff --git a/models/audio/speech_recognition/conformer/igie/wenet/runtime/raspberrypi/utils/thread_pool.h b/models/audio/speech_recognition/conformer/igie/wenet/runtime/raspberrypi/utils/thread_pool.h index 
03cc45b8b2e4bc273706353f248ef088db0eaca8..a78162995d90bf079ad091cf14cb9f2cd4476d05 100644 --- a/models/audio/speech_recognition/conformer/igie/wenet/runtime/raspberrypi/utils/thread_pool.h +++ b/models/audio/speech_recognition/conformer/igie/wenet/runtime/raspberrypi/utils/thread_pool.h @@ -1,113 +1,113 @@ -// Copyright (c) 2012 Jakob Progsch, Václav Zeman - -// This software is provided 'as-is', without any express or implied -// warranty. In no event will the authors be held liable for any damages -// arising from the use of this software. - -// Permission is granted to anyone to use this software for any purpose, -// including commercial applications, and to alter it and redistribute it -// freely, subject to the following restrictions: - -// 1. The origin of this software must not be misrepresented; you must not -// claim that you wrote the original software. If you use this software -// in a product, an acknowledgment in the product documentation would be -// appreciated but is not required. - -// 2. Altered source versions must be plainly marked as such, and must not be -// misrepresented as being the original software. - -// 3. This notice may not be removed or altered from any source -// distribution. - -#ifndef UTILS_THREAD_POOL_H_ -#define UTILS_THREAD_POOL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -class ThreadPool { - public: - explicit ThreadPool(size_t); - template - auto enqueue(F&& f, Args&&... args) - -> std::future::type>; - ~ThreadPool(); - - private: - // need to keep track of threads so we can join them - std::vector workers; - // the task queue - std::queue > tasks; - - // synchronization - std::mutex queue_mutex; - std::condition_variable condition; - bool stop; -}; - -// the constructor just launches some amount of workers -inline ThreadPool::ThreadPool(size_t threads) : stop(false) { - for (size_t i = 0; i < threads; ++i) - workers.emplace_back([this] { - for (;;) { - std::function task; - - { - std::unique_lock lock(this->queue_mutex); - this->condition.wait( - lock, [this] { return this->stop || !this->tasks.empty(); }); - if (this->stop && this->tasks.empty()) return; - task = std::move(this->tasks.front()); - this->tasks.pop(); - } - - task(); - } - }); -} - -// add new work item to the pool -template -auto ThreadPool::enqueue(F&& f, Args&&... args) - -> std::future::type> { - using return_type = typename std::result_of::type; - - auto task = std::make_shared >( - std::bind(std::forward(f), std::forward(args)...)); - - std::future res = task->get_future(); - { - std::unique_lock lock(queue_mutex); - - // don't allow enqueueing after stopping the pool - if (stop) { - throw std::runtime_error("enqueue on stopped ThreadPool"); - } - - tasks.emplace([task]() { (*task)(); }); - } - condition.notify_one(); - return res; -} - -// the destructor joins all threads -inline ThreadPool::~ThreadPool() { - { - std::unique_lock lock(queue_mutex); - stop = true; - } - condition.notify_all(); - for (std::thread& worker : workers) { - worker.join(); - } -} - -#endif // UTILS_THREAD_POOL_H_ +// Copyright (c) 2012 Jakob Progsch, Václav Zeman + +// This software is provided 'as-is', without any express or implied +// warranty. In no event will the authors be held liable for any damages +// arising from the use of this software. 
+ +// Permission is granted to anyone to use this software for any purpose, +// including commercial applications, and to alter it and redistribute it +// freely, subject to the following restrictions: + +// 1. The origin of this software must not be misrepresented; you must not +// claim that you wrote the original software. If you use this software +// in a product, an acknowledgment in the product documentation would be +// appreciated but is not required. + +// 2. Altered source versions must be plainly marked as such, and must not be +// misrepresented as being the original software. + +// 3. This notice may not be removed or altered from any source +// distribution. + +#ifndef UTILS_THREAD_POOL_H_ +#define UTILS_THREAD_POOL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +class ThreadPool { + public: + explicit ThreadPool(size_t); + template + auto enqueue(F&& f, Args&&... args) + -> std::future::type>; + ~ThreadPool(); + + private: + // need to keep track of threads so we can join them + std::vector workers; + // the task queue + std::queue > tasks; + + // synchronization + std::mutex queue_mutex; + std::condition_variable condition; + bool stop; +}; + +// the constructor just launches some amount of workers +inline ThreadPool::ThreadPool(size_t threads) : stop(false) { + for (size_t i = 0; i < threads; ++i) + workers.emplace_back([this] { + for (;;) { + std::function task; + + { + std::unique_lock lock(this->queue_mutex); + this->condition.wait( + lock, [this] { return this->stop || !this->tasks.empty(); }); + if (this->stop && this->tasks.empty()) return; + task = std::move(this->tasks.front()); + this->tasks.pop(); + } + + task(); + } + }); +} + +// add new work item to the pool +template +auto ThreadPool::enqueue(F&& f, Args&&... 
args) + -> std::future::type> { + using return_type = typename std::result_of::type; + + auto task = std::make_shared >( + std::bind(std::forward(f), std::forward(args)...)); + + std::future res = task->get_future(); + { + std::unique_lock lock(queue_mutex); + + // don't allow enqueueing after stopping the pool + if (stop) { + throw std::runtime_error("enqueue on stopped ThreadPool"); + } + + tasks.emplace([task]() { (*task)(); }); + } + condition.notify_one(); + return res; +} + +// the destructor joins all threads +inline ThreadPool::~ThreadPool() { + { + std::unique_lock lock(queue_mutex); + stop = true; + } + condition.notify_all(); + for (std::thread& worker : workers) { + worker.join(); + } +} + +#endif // UTILS_THREAD_POOL_H_ diff --git a/models/cv/classification/alexnet/igie/requirements.txt b/models/cv/classification/alexnet/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/alexnet/igie/requirements.txt +++ b/models/cv/classification/alexnet/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/clip/igie/requirements.txt b/models/cv/classification/clip/igie/requirements.txt index a65219de4c04c72e5c7a76a46dea9fd60a287d91..b3cc1fc6b995d97ae345ad9c6eed9f1581fd1f77 100644 --- a/models/cv/classification/clip/igie/requirements.txt +++ b/models/cv/classification/clip/igie/requirements.txt @@ -1,3 +1,3 @@ -tqdm -onnxsim -transformers==4.33.2 +tqdm +onnxsim +transformers==4.33.2 diff --git a/models/cv/classification/conformer_base/igie/requirements.txt b/models/cv/classification/conformer_base/igie/requirements.txt index ed69cff2ac668b70f0875f53f7935aa3709aeef0..f36d0c071373f7e81dce31bfca69a600e7fd5ca1 100644 --- a/models/cv/classification/conformer_base/igie/requirements.txt +++ b/models/cv/classification/conformer_base/igie/requirements.txt @@ -1,4 +1,4 @@ -onnx -tqdm -timm -onnxsim +onnx +tqdm +timm +onnxsim diff --git a/models/cv/classification/convnext_base/igie/requirements.txt b/models/cv/classification/convnext_base/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/convnext_base/igie/requirements.txt +++ b/models/cv/classification/convnext_base/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/convnext_s/igie/requirements.txt b/models/cv/classification/convnext_s/igie/requirements.txt index 63629c9686ef1b1fb60ccadc4e5ba2799b693f89..41c3166395b56ce698ec11c8f3aef19624cae2bb 100644 --- a/models/cv/classification/convnext_s/igie/requirements.txt +++ b/models/cv/classification/convnext_s/igie/requirements.txt @@ -1,5 +1,5 @@ -onnx -tqdm -onnxsim -mmcv==1.5.3 -mmcls +onnx +tqdm +onnxsim +mmcv==1.5.3 +mmcls diff --git a/models/cv/classification/convnext_small/igie/requirements.txt b/models/cv/classification/convnext_small/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/convnext_small/igie/requirements.txt +++ b/models/cv/classification/convnext_small/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/cspdarknet53/igie/requirements.txt b/models/cv/classification/cspdarknet53/igie/requirements.txt index 63629c9686ef1b1fb60ccadc4e5ba2799b693f89..41c3166395b56ce698ec11c8f3aef19624cae2bb 100644 --- 
a/models/cv/classification/cspdarknet53/igie/requirements.txt +++ b/models/cv/classification/cspdarknet53/igie/requirements.txt @@ -1,5 +1,5 @@ -onnx -tqdm -onnxsim -mmcv==1.5.3 -mmcls +onnx +tqdm +onnxsim +mmcv==1.5.3 +mmcls diff --git a/models/cv/classification/cspresnet50/igie/requirements.txt b/models/cv/classification/cspresnet50/igie/requirements.txt index 63629c9686ef1b1fb60ccadc4e5ba2799b693f89..41c3166395b56ce698ec11c8f3aef19624cae2bb 100644 --- a/models/cv/classification/cspresnet50/igie/requirements.txt +++ b/models/cv/classification/cspresnet50/igie/requirements.txt @@ -1,5 +1,5 @@ -onnx -tqdm -onnxsim -mmcv==1.5.3 -mmcls +onnx +tqdm +onnxsim +mmcv==1.5.3 +mmcls diff --git a/models/cv/classification/deit_tiny/igie/requirements.txt b/models/cv/classification/deit_tiny/igie/requirements.txt index 63629c9686ef1b1fb60ccadc4e5ba2799b693f89..41c3166395b56ce698ec11c8f3aef19624cae2bb 100644 --- a/models/cv/classification/deit_tiny/igie/requirements.txt +++ b/models/cv/classification/deit_tiny/igie/requirements.txt @@ -1,5 +1,5 @@ -onnx -tqdm -onnxsim -mmcv==1.5.3 -mmcls +onnx +tqdm +onnxsim +mmcv==1.5.3 +mmcls diff --git a/models/cv/classification/densenet121/igie/requirements.txt b/models/cv/classification/densenet121/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/densenet121/igie/requirements.txt +++ b/models/cv/classification/densenet121/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/densenet161/igie/requirements.txt b/models/cv/classification/densenet161/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/densenet161/igie/requirements.txt +++ b/models/cv/classification/densenet161/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/densenet169/igie/requirements.txt b/models/cv/classification/densenet169/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/densenet169/igie/requirements.txt +++ b/models/cv/classification/densenet169/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/densenet201/igie/requirements.txt b/models/cv/classification/densenet201/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/densenet201/igie/requirements.txt +++ b/models/cv/classification/densenet201/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/efficientnet_b0/igie/requirements.txt b/models/cv/classification/efficientnet_b0/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/efficientnet_b0/igie/requirements.txt +++ b/models/cv/classification/efficientnet_b0/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/efficientnet_b1/igie/requirements.txt b/models/cv/classification/efficientnet_b1/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/efficientnet_b1/igie/requirements.txt +++ b/models/cv/classification/efficientnet_b1/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx 
+tqdm diff --git a/models/cv/classification/efficientnet_b2/igie/requirements.txt b/models/cv/classification/efficientnet_b2/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/efficientnet_b2/igie/requirements.txt +++ b/models/cv/classification/efficientnet_b2/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/efficientnet_b3/igie/requirements.txt b/models/cv/classification/efficientnet_b3/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/efficientnet_b3/igie/requirements.txt +++ b/models/cv/classification/efficientnet_b3/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/efficientnet_b4/igie/requirements.txt b/models/cv/classification/efficientnet_b4/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/efficientnet_b4/igie/requirements.txt +++ b/models/cv/classification/efficientnet_b4/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/efficientnet_v2/igie/requirements.txt b/models/cv/classification/efficientnet_v2/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/efficientnet_v2/igie/requirements.txt +++ b/models/cv/classification/efficientnet_v2/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/efficientnet_v2_s/igie/requirements.txt b/models/cv/classification/efficientnet_v2_s/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/efficientnet_v2_s/igie/requirements.txt +++ b/models/cv/classification/efficientnet_v2_s/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/efficientnet_v2_s/ixrt/requirements.txt b/models/cv/classification/efficientnet_v2_s/ixrt/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/efficientnet_v2_s/ixrt/requirements.txt +++ b/models/cv/classification/efficientnet_v2_s/ixrt/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/efficientnetv2_rw_t/igie/requirements.txt b/models/cv/classification/efficientnetv2_rw_t/igie/requirements.txt index 3b2080776ddab7648ecbad549b7262579c29f0d5..36677a29ab3a81e04e55e2185513580169404d15 100644 --- a/models/cv/classification/efficientnetv2_rw_t/igie/requirements.txt +++ b/models/cv/classification/efficientnetv2_rw_t/igie/requirements.txt @@ -1,3 +1,3 @@ -timm -onnx -tqdm +timm +onnx +tqdm diff --git a/models/cv/classification/googlenet/igie/requirements.txt b/models/cv/classification/googlenet/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/googlenet/igie/requirements.txt +++ b/models/cv/classification/googlenet/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/hrnet_w18/igie/requirements.txt b/models/cv/classification/hrnet_w18/igie/requirements.txt index 
63629c9686ef1b1fb60ccadc4e5ba2799b693f89..41c3166395b56ce698ec11c8f3aef19624cae2bb 100644 --- a/models/cv/classification/hrnet_w18/igie/requirements.txt +++ b/models/cv/classification/hrnet_w18/igie/requirements.txt @@ -1,5 +1,5 @@ -onnx -tqdm -onnxsim -mmcv==1.5.3 -mmcls +onnx +tqdm +onnxsim +mmcv==1.5.3 +mmcls diff --git a/models/cv/classification/inception_v3/igie/requirements.txt b/models/cv/classification/inception_v3/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/inception_v3/igie/requirements.txt +++ b/models/cv/classification/inception_v3/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/mlp_mixer_base/igie/requirements.txt b/models/cv/classification/mlp_mixer_base/igie/requirements.txt index 63629c9686ef1b1fb60ccadc4e5ba2799b693f89..41c3166395b56ce698ec11c8f3aef19624cae2bb 100644 --- a/models/cv/classification/mlp_mixer_base/igie/requirements.txt +++ b/models/cv/classification/mlp_mixer_base/igie/requirements.txt @@ -1,5 +1,5 @@ -onnx -tqdm -onnxsim -mmcv==1.5.3 -mmcls +onnx +tqdm +onnxsim +mmcv==1.5.3 +mmcls diff --git a/models/cv/classification/mnasnet0_5/igie/requirements.txt b/models/cv/classification/mnasnet0_5/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/mnasnet0_5/igie/requirements.txt +++ b/models/cv/classification/mnasnet0_5/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/mnasnet0_75/igie/requirements.txt b/models/cv/classification/mnasnet0_75/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/mnasnet0_75/igie/requirements.txt +++ b/models/cv/classification/mnasnet0_75/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/mobilenet_v2/igie/requirements.txt b/models/cv/classification/mobilenet_v2/igie/requirements.txt index 5c52b7c9d257ee9259cb04e3344e12a606e921c2..08a0a972e3bf75e58f513f312b12d56d7a38d3b6 100644 --- a/models/cv/classification/mobilenet_v2/igie/requirements.txt +++ b/models/cv/classification/mobilenet_v2/igie/requirements.txt @@ -1,3 +1,3 @@ -onnx -tqdm +onnx +tqdm onnxruntime-gpu==1.18.0 \ No newline at end of file diff --git a/models/cv/classification/mobilenet_v3/igie/requirements.txt b/models/cv/classification/mobilenet_v3/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/mobilenet_v3/igie/requirements.txt +++ b/models/cv/classification/mobilenet_v3/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/mvitv2_base/igie/requirements.txt b/models/cv/classification/mvitv2_base/igie/requirements.txt index 63629c9686ef1b1fb60ccadc4e5ba2799b693f89..41c3166395b56ce698ec11c8f3aef19624cae2bb 100644 --- a/models/cv/classification/mvitv2_base/igie/requirements.txt +++ b/models/cv/classification/mvitv2_base/igie/requirements.txt @@ -1,5 +1,5 @@ -onnx -tqdm -onnxsim -mmcv==1.5.3 -mmcls +onnx +tqdm +onnxsim +mmcv==1.5.3 +mmcls diff --git a/models/cv/classification/regnet_x_16gf/igie/requirements.txt b/models/cv/classification/regnet_x_16gf/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- 
a/models/cv/classification/regnet_x_16gf/igie/requirements.txt +++ b/models/cv/classification/regnet_x_16gf/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/regnet_x_1_6gf/igie/requirements.txt b/models/cv/classification/regnet_x_1_6gf/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/regnet_x_1_6gf/igie/requirements.txt +++ b/models/cv/classification/regnet_x_1_6gf/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/regnet_y_1_6gf/igie/requirements.txt b/models/cv/classification/regnet_y_1_6gf/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/regnet_y_1_6gf/igie/requirements.txt +++ b/models/cv/classification/regnet_y_1_6gf/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/repvgg/igie/requirements.txt b/models/cv/classification/repvgg/igie/requirements.txt index 637eab19e8101e93942e8177d7739ce2ee6cb794..2bd089c8fdcd7755d0ca576c42f26d79e2e7c433 100644 --- a/models/cv/classification/repvgg/igie/requirements.txt +++ b/models/cv/classification/repvgg/igie/requirements.txt @@ -1,5 +1,5 @@ -onnx -tqdm -mmcv==1.5.3 -mmcls -mmengine +onnx +tqdm +mmcv==1.5.3 +mmcls +mmengine diff --git a/models/cv/classification/res2net50/igie/requirements.txt b/models/cv/classification/res2net50/igie/requirements.txt index 63629c9686ef1b1fb60ccadc4e5ba2799b693f89..41c3166395b56ce698ec11c8f3aef19624cae2bb 100644 --- a/models/cv/classification/res2net50/igie/requirements.txt +++ b/models/cv/classification/res2net50/igie/requirements.txt @@ -1,5 +1,5 @@ -onnx -tqdm -onnxsim -mmcv==1.5.3 -mmcls +onnx +tqdm +onnxsim +mmcv==1.5.3 +mmcls diff --git a/models/cv/classification/resnest50/igie/requirements.txt b/models/cv/classification/resnest50/igie/requirements.txt index d337b1c18b6f129e0e0d8c3e1d9bfe6b41db55fc..b6cbf88fb82dea30f99d75e35623ebbb72eeeaa7 100644 --- a/models/cv/classification/resnest50/igie/requirements.txt +++ b/models/cv/classification/resnest50/igie/requirements.txt @@ -1,4 +1,4 @@ -onnx -tqdm -onnxsim -git+https://github.com/zhanghang1989/ResNeSt +onnx +tqdm +onnxsim +git+https://github.com/zhanghang1989/ResNeSt diff --git a/models/cv/classification/resnet101/igie/requirements.txt b/models/cv/classification/resnet101/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/resnet101/igie/requirements.txt +++ b/models/cv/classification/resnet101/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/resnet152/igie/requirements.txt b/models/cv/classification/resnet152/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/resnet152/igie/requirements.txt +++ b/models/cv/classification/resnet152/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/resnet18/igie/requirements.txt b/models/cv/classification/resnet18/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/resnet18/igie/requirements.txt +++ b/models/cv/classification/resnet18/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff 
--git a/models/cv/classification/resnet50/igie/requirements.txt b/models/cv/classification/resnet50/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/resnet50/igie/requirements.txt +++ b/models/cv/classification/resnet50/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/resnetv1d50/igie/requirements.txt b/models/cv/classification/resnetv1d50/igie/requirements.txt index 266a11e3e679bc51cac818ce66c35366f8c3a62c..4d5ea05fc6b67b4fd9274544384d4473c2994a5b 100644 --- a/models/cv/classification/resnetv1d50/igie/requirements.txt +++ b/models/cv/classification/resnetv1d50/igie/requirements.txt @@ -1,4 +1,4 @@ -onnx -tqdm -mmcv==1.5.3 -mmcls +onnx +tqdm +mmcv==1.5.3 +mmcls diff --git a/models/cv/classification/resnext101_32x8d/igie/requirements.txt b/models/cv/classification/resnext101_32x8d/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/resnext101_32x8d/igie/requirements.txt +++ b/models/cv/classification/resnext101_32x8d/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/resnext101_64x4d/igie/requirements.txt b/models/cv/classification/resnext101_64x4d/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/resnext101_64x4d/igie/requirements.txt +++ b/models/cv/classification/resnext101_64x4d/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/resnext50_32x4d/igie/requirements.txt b/models/cv/classification/resnext50_32x4d/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/resnext50_32x4d/igie/requirements.txt +++ b/models/cv/classification/resnext50_32x4d/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/shufflenetv2_x0_5/igie/requirements.txt b/models/cv/classification/shufflenetv2_x0_5/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/shufflenetv2_x0_5/igie/requirements.txt +++ b/models/cv/classification/shufflenetv2_x0_5/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/shufflenetv2_x1_0/igie/requirements.txt b/models/cv/classification/shufflenetv2_x1_0/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/shufflenetv2_x1_0/igie/requirements.txt +++ b/models/cv/classification/shufflenetv2_x1_0/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/shufflenetv2_x1_5/igie/requirements.txt b/models/cv/classification/shufflenetv2_x1_5/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/shufflenetv2_x1_5/igie/requirements.txt +++ b/models/cv/classification/shufflenetv2_x1_5/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/shufflenetv2_x2_0/igie/requirements.txt b/models/cv/classification/shufflenetv2_x2_0/igie/requirements.txt index 
4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/shufflenetv2_x2_0/igie/requirements.txt +++ b/models/cv/classification/shufflenetv2_x2_0/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/squeezenet_v1_0/igie/requirements.txt b/models/cv/classification/squeezenet_v1_0/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/squeezenet_v1_0/igie/requirements.txt +++ b/models/cv/classification/squeezenet_v1_0/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/svt_base/igie/requirements.txt b/models/cv/classification/svt_base/igie/requirements.txt index 63629c9686ef1b1fb60ccadc4e5ba2799b693f89..41c3166395b56ce698ec11c8f3aef19624cae2bb 100644 --- a/models/cv/classification/svt_base/igie/requirements.txt +++ b/models/cv/classification/svt_base/igie/requirements.txt @@ -1,5 +1,5 @@ -onnx -tqdm -onnxsim -mmcv==1.5.3 -mmcls +onnx +tqdm +onnxsim +mmcv==1.5.3 +mmcls diff --git a/models/cv/classification/swin_transformer/igie/requirements.txt b/models/cv/classification/swin_transformer/igie/requirements.txt index a65219de4c04c72e5c7a76a46dea9fd60a287d91..b3cc1fc6b995d97ae345ad9c6eed9f1581fd1f77 100644 --- a/models/cv/classification/swin_transformer/igie/requirements.txt +++ b/models/cv/classification/swin_transformer/igie/requirements.txt @@ -1,3 +1,3 @@ -tqdm -onnxsim -transformers==4.33.2 +tqdm +onnxsim +transformers==4.33.2 diff --git a/models/cv/classification/vgg11/igie/requirements.txt b/models/cv/classification/vgg11/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/vgg11/igie/requirements.txt +++ b/models/cv/classification/vgg11/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/vgg16/igie/requirements.txt b/models/cv/classification/vgg16/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/vgg16/igie/requirements.txt +++ b/models/cv/classification/vgg16/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/wide_resnet101/igie/requirements.txt b/models/cv/classification/wide_resnet101/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/wide_resnet101/igie/requirements.txt +++ b/models/cv/classification/wide_resnet101/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/classification/wide_resnet50/igie/requirements.txt b/models/cv/classification/wide_resnet50/igie/requirements.txt index 4c1d32d3df61899847e6d50edd8f8a42470c63f4..9e8111264d4bb2c985cdd10c1de3b894d4e50bef 100644 --- a/models/cv/classification/wide_resnet50/igie/requirements.txt +++ b/models/cv/classification/wide_resnet50/igie/requirements.txt @@ -1,2 +1,2 @@ -onnx -tqdm +onnx +tqdm diff --git a/models/cv/multi_object_tracking/deepsort/igie/requirements.txt b/models/cv/multi_object_tracking/deepsort/igie/requirements.txt index e734bc4afbf24eb004e1a20719c1f6b99920da0b..0516c70db6c22f6bd9cad91f0f3ca73f51bb0bea 100644 --- a/models/cv/multi_object_tracking/deepsort/igie/requirements.txt +++ b/models/cv/multi_object_tracking/deepsort/igie/requirements.txt @@ -1,3 
+1,3 @@ -onnx -tqdm -onnxsim +onnx +tqdm +onnxsim diff --git a/models/cv/multi_object_tracking/fastreid/igie/requirements.txt b/models/cv/multi_object_tracking/fastreid/igie/requirements.txt index 10d32c0fcdb64a7d6265453178bef2421bf7413d..b67b3134e04c820d471ff6f17447cdb01fdd2b52 100644 --- a/models/cv/multi_object_tracking/fastreid/igie/requirements.txt +++ b/models/cv/multi_object_tracking/fastreid/igie/requirements.txt @@ -1,4 +1,4 @@ -onnx -tqdm -onnxsim +onnx +tqdm +onnxsim onnxoptimizer \ No newline at end of file diff --git a/models/cv/multi_object_tracking/repnet/igie/requirements.txt b/models/cv/multi_object_tracking/repnet/igie/requirements.txt index e734bc4afbf24eb004e1a20719c1f6b99920da0b..0516c70db6c22f6bd9cad91f0f3ca73f51bb0bea 100644 --- a/models/cv/multi_object_tracking/repnet/igie/requirements.txt +++ b/models/cv/multi_object_tracking/repnet/igie/requirements.txt @@ -1,3 +1,3 @@ -onnx -tqdm -onnxsim +onnx +tqdm +onnxsim diff --git a/models/cv/object_detection/atss/igie/requirements.txt b/models/cv/object_detection/atss/igie/requirements.txt index a26706ef5402ca820ed6e4ab952d876ec768b4eb..b6b3fff4aa66c7401a67b2874e40c3caf154a34d 100644 --- a/models/cv/object_detection/atss/igie/requirements.txt +++ b/models/cv/object_detection/atss/igie/requirements.txt @@ -1,6 +1,6 @@ -onnx -tqdm -onnxsim -mmdet==3.3.0 -mmdeploy==1.3.1 -mmengine==0.10.4 +onnx +tqdm +onnxsim +mmdet==3.3.0 +mmdeploy==1.3.1 +mmengine==0.10.4 diff --git a/models/cv/object_detection/centernet/igie/requirements.txt b/models/cv/object_detection/centernet/igie/requirements.txt index f7e7fc9a5592f4edd681a41e340dcbec21d9a40b..71ef3c22fb93b30a8a56bfed688123b3e89ac26a 100644 --- a/models/cv/object_detection/centernet/igie/requirements.txt +++ b/models/cv/object_detection/centernet/igie/requirements.txt @@ -1,5 +1,5 @@ -onnx -tqdm -mmdet -mmdeploy -mmengine +onnx +tqdm +mmdet +mmdeploy +mmengine diff --git a/models/cv/object_detection/centernet/ixrt/requirements.txt b/models/cv/object_detection/centernet/ixrt/requirements.txt index 9178d0b61aa4155c1effbf468da32f2b8dd9f96d..291a7172f463a5a6759f5d624502572d71469d55 100644 --- a/models/cv/object_detection/centernet/ixrt/requirements.txt +++ b/models/cv/object_detection/centernet/ixrt/requirements.txt @@ -1,6 +1,6 @@ -onnx -tqdm -mmdet -mmdeploy -mmengine +onnx +tqdm +mmdet +mmdeploy +mmengine transformers==4.37.1 \ No newline at end of file diff --git a/models/cv/object_detection/detr/ixrt/requirements.txt b/models/cv/object_detection/detr/ixrt/requirements.txt index 967c8c6817fad277f415780c69ab80f53edd54c0..276602c39b221e8d5b0d24cd3b456322b1b36ec7 100644 --- a/models/cv/object_detection/detr/ixrt/requirements.txt +++ b/models/cv/object_detection/detr/ixrt/requirements.txt @@ -1,7 +1,7 @@ -tqdm -pycuda -onnx -onnxsim -tabulate -pycocotools +tqdm +pycuda +onnx +onnxsim +tabulate +pycocotools opencv-python==4.6.0.66 \ No newline at end of file diff --git a/models/cv/object_detection/fcos/igie/requirements.txt b/models/cv/object_detection/fcos/igie/requirements.txt index a26706ef5402ca820ed6e4ab952d876ec768b4eb..b6b3fff4aa66c7401a67b2874e40c3caf154a34d 100644 --- a/models/cv/object_detection/fcos/igie/requirements.txt +++ b/models/cv/object_detection/fcos/igie/requirements.txt @@ -1,6 +1,6 @@ -onnx -tqdm -onnxsim -mmdet==3.3.0 -mmdeploy==1.3.1 -mmengine==0.10.4 +onnx +tqdm +onnxsim +mmdet==3.3.0 +mmdeploy==1.3.1 +mmengine==0.10.4 diff --git a/models/cv/object_detection/fcos/ixrt/requirements.txt b/models/cv/object_detection/fcos/ixrt/requirements.txt index 
3a911f40d22c62d06c2a2be249831156de20c265..a0763974b54feecde9c5a7654327122855e85eed 100644 --- a/models/cv/object_detection/fcos/ixrt/requirements.txt +++ b/models/cv/object_detection/fcos/ixrt/requirements.txt @@ -1,10 +1,10 @@ -tqdm -onnx -onnxsim -ultralytics -pycocotools -addict -yapf -pycuda -mmdet==2.28.2 +tqdm +onnx +onnxsim +ultralytics +pycocotools +addict +yapf +pycuda +mmdet==2.28.2 opencv-python==4.6.0.66 \ No newline at end of file diff --git a/models/cv/object_detection/foveabox/igie/requirements.txt b/models/cv/object_detection/foveabox/igie/requirements.txt index 97ac9c0458744fb56d62781ffd96279f893817f3..073c19fba032df6cb08ccea0364a9d87103dcc60 100644 --- a/models/cv/object_detection/foveabox/igie/requirements.txt +++ b/models/cv/object_detection/foveabox/igie/requirements.txt @@ -1,6 +1,6 @@ -onnx -tqdm -onnxsim -mmdet -mmdeploy -mmengine +onnx +tqdm +onnxsim +mmdet +mmdeploy +mmengine diff --git a/models/cv/object_detection/fsaf/igie/requirements.txt b/models/cv/object_detection/fsaf/igie/requirements.txt index a26706ef5402ca820ed6e4ab952d876ec768b4eb..b6b3fff4aa66c7401a67b2874e40c3caf154a34d 100644 --- a/models/cv/object_detection/fsaf/igie/requirements.txt +++ b/models/cv/object_detection/fsaf/igie/requirements.txt @@ -1,6 +1,6 @@ -onnx -tqdm -onnxsim -mmdet==3.3.0 -mmdeploy==1.3.1 -mmengine==0.10.4 +onnx +tqdm +onnxsim +mmdet==3.3.0 +mmdeploy==1.3.1 +mmengine==0.10.4 diff --git a/models/cv/object_detection/fsaf/ixrt/requirements.txt b/models/cv/object_detection/fsaf/ixrt/requirements.txt index a26706ef5402ca820ed6e4ab952d876ec768b4eb..b6b3fff4aa66c7401a67b2874e40c3caf154a34d 100644 --- a/models/cv/object_detection/fsaf/ixrt/requirements.txt +++ b/models/cv/object_detection/fsaf/ixrt/requirements.txt @@ -1,6 +1,6 @@ -onnx -tqdm -onnxsim -mmdet==3.3.0 -mmdeploy==1.3.1 -mmengine==0.10.4 +onnx +tqdm +onnxsim +mmdet==3.3.0 +mmdeploy==1.3.1 +mmengine==0.10.4 diff --git a/models/cv/object_detection/hrnet/igie/requirements.txt b/models/cv/object_detection/hrnet/igie/requirements.txt index 97ac9c0458744fb56d62781ffd96279f893817f3..073c19fba032df6cb08ccea0364a9d87103dcc60 100644 --- a/models/cv/object_detection/hrnet/igie/requirements.txt +++ b/models/cv/object_detection/hrnet/igie/requirements.txt @@ -1,6 +1,6 @@ -onnx -tqdm -onnxsim -mmdet -mmdeploy -mmengine +onnx +tqdm +onnxsim +mmdet +mmdeploy +mmengine diff --git a/models/cv/object_detection/hrnet/ixrt/requirements.txt b/models/cv/object_detection/hrnet/ixrt/requirements.txt index 97ac9c0458744fb56d62781ffd96279f893817f3..073c19fba032df6cb08ccea0364a9d87103dcc60 100644 --- a/models/cv/object_detection/hrnet/ixrt/requirements.txt +++ b/models/cv/object_detection/hrnet/ixrt/requirements.txt @@ -1,6 +1,6 @@ -onnx -tqdm -onnxsim -mmdet -mmdeploy -mmengine +onnx +tqdm +onnxsim +mmdet +mmdeploy +mmengine diff --git a/models/cv/object_detection/paa/igie/requirements.txt b/models/cv/object_detection/paa/igie/requirements.txt index 97ac9c0458744fb56d62781ffd96279f893817f3..073c19fba032df6cb08ccea0364a9d87103dcc60 100644 --- a/models/cv/object_detection/paa/igie/requirements.txt +++ b/models/cv/object_detection/paa/igie/requirements.txt @@ -1,6 +1,6 @@ -onnx -tqdm -onnxsim -mmdet -mmdeploy -mmengine +onnx +tqdm +onnxsim +mmdet +mmdeploy +mmengine diff --git a/models/cv/object_detection/retinaface/igie/requirements.txt b/models/cv/object_detection/retinaface/igie/requirements.txt index 19045455ed1528b895ded8d5812235d884f37ade..1d7f9e0ba37c9e14882d25f2cad7b6326c681952 100644 --- 
a/models/cv/object_detection/retinaface/igie/requirements.txt +++ b/models/cv/object_detection/retinaface/igie/requirements.txt @@ -1,4 +1,4 @@ -onnx -tqdm -onnxsim -opencv-python==4.6.0.66 +onnx +tqdm +onnxsim +opencv-python==4.6.0.66 diff --git a/models/cv/object_detection/retinanet/igie/requirements.txt b/models/cv/object_detection/retinanet/igie/requirements.txt index 97ac9c0458744fb56d62781ffd96279f893817f3..073c19fba032df6cb08ccea0364a9d87103dcc60 100644 --- a/models/cv/object_detection/retinanet/igie/requirements.txt +++ b/models/cv/object_detection/retinanet/igie/requirements.txt @@ -1,6 +1,6 @@ -onnx -tqdm -onnxsim -mmdet -mmdeploy -mmengine +onnx +tqdm +onnxsim +mmdet +mmdeploy +mmengine diff --git a/models/cv/object_detection/rtmdet/igie/requirements.txt b/models/cv/object_detection/rtmdet/igie/requirements.txt index a26706ef5402ca820ed6e4ab952d876ec768b4eb..b6b3fff4aa66c7401a67b2874e40c3caf154a34d 100644 --- a/models/cv/object_detection/rtmdet/igie/requirements.txt +++ b/models/cv/object_detection/rtmdet/igie/requirements.txt @@ -1,6 +1,6 @@ -onnx -tqdm -onnxsim -mmdet==3.3.0 -mmdeploy==1.3.1 -mmengine==0.10.4 +onnx +tqdm +onnxsim +mmdet==3.3.0 +mmdeploy==1.3.1 +mmengine==0.10.4 diff --git a/models/cv/object_detection/sabl/igie/requirements.txt b/models/cv/object_detection/sabl/igie/requirements.txt index a26706ef5402ca820ed6e4ab952d876ec768b4eb..b6b3fff4aa66c7401a67b2874e40c3caf154a34d 100644 --- a/models/cv/object_detection/sabl/igie/requirements.txt +++ b/models/cv/object_detection/sabl/igie/requirements.txt @@ -1,6 +1,6 @@ -onnx -tqdm -onnxsim -mmdet==3.3.0 -mmdeploy==1.3.1 -mmengine==0.10.4 +onnx +tqdm +onnxsim +mmdet==3.3.0 +mmdeploy==1.3.1 +mmengine==0.10.4 diff --git a/models/cv/object_detection/yolov11/igie/requirements.txt b/models/cv/object_detection/yolov11/igie/requirements.txt index dc9f18f87843e5d8b5559fcdfbc44cad96698013..72c3e77cc563c17d86fe72609162b2b17e91c72f 100644 --- a/models/cv/object_detection/yolov11/igie/requirements.txt +++ b/models/cv/object_detection/yolov11/igie/requirements.txt @@ -1,4 +1,4 @@ -tqdm -onnx==1.13.0 -onnxsim==0.4.36 -ultralytics==8.3.59 +tqdm +onnx==1.13.0 +onnxsim==0.4.36 +ultralytics==8.3.59 diff --git a/models/cv/object_detection/yolov3/igie/requirements.txt b/models/cv/object_detection/yolov3/igie/requirements.txt index 2628a268a2467f0db27e37e39fa0544bd3fb74a1..171602527bdea43ee2216f9ad4629d83cfd92e38 100644 --- a/models/cv/object_detection/yolov3/igie/requirements.txt +++ b/models/cv/object_detection/yolov3/igie/requirements.txt @@ -1,5 +1,5 @@ -tqdm -onnx -onnxsim -ultralytics -pycocotools +tqdm +onnx +onnxsim +ultralytics +pycocotools diff --git a/models/cv/object_detection/yolov3/ixrt/requirements.txt b/models/cv/object_detection/yolov3/ixrt/requirements.txt index f2ec37c1129a168dac9920da4cebdfe78169841a..b0f4374b2b778c81875da50d088fecedd01689c9 100644 --- a/models/cv/object_detection/yolov3/ixrt/requirements.txt +++ b/models/cv/object_detection/yolov3/ixrt/requirements.txt @@ -1,7 +1,7 @@ -tqdm -onnx -onnxsim -ultralytics -pycocotools -opencv-python==4.6.0.66 +tqdm +onnx +onnxsim +ultralytics +pycocotools +opencv-python==4.6.0.66 pycuda \ No newline at end of file diff --git a/models/cv/object_detection/yolov4/igie/requirements.txt b/models/cv/object_detection/yolov4/igie/requirements.txt index 83d1bed634d31ba3adccbeefd7e3b907d4f4d13f..238c13c62c610c88e9a33ab455ecab236cf72832 100644 --- a/models/cv/object_detection/yolov4/igie/requirements.txt +++ b/models/cv/object_detection/yolov4/igie/requirements.txt @@ -1,4 +1,4 @@ 
-tqdm -onnx -onnxsim -pycocotools +tqdm +onnx +onnxsim +pycocotools diff --git a/models/cv/object_detection/yolov4/ixrt/requirements.txt b/models/cv/object_detection/yolov4/ixrt/requirements.txt index c5ff461d18d82755007c5cc116aedbeb8a02574e..9dcd9ab72944e697ad89bdd11e924e519bf6f334 100644 --- a/models/cv/object_detection/yolov4/ixrt/requirements.txt +++ b/models/cv/object_detection/yolov4/ixrt/requirements.txt @@ -1,5 +1,5 @@ -tqdm -onnx -onnxsim -pycocotools +tqdm +onnx +onnxsim +pycocotools pycuda \ No newline at end of file diff --git a/models/cv/object_detection/yolov5/igie/requirements.txt b/models/cv/object_detection/yolov5/igie/requirements.txt index 2628a268a2467f0db27e37e39fa0544bd3fb74a1..171602527bdea43ee2216f9ad4629d83cfd92e38 100644 --- a/models/cv/object_detection/yolov5/igie/requirements.txt +++ b/models/cv/object_detection/yolov5/igie/requirements.txt @@ -1,5 +1,5 @@ -tqdm -onnx -onnxsim -ultralytics -pycocotools +tqdm +onnx +onnxsim +ultralytics +pycocotools diff --git a/models/cv/object_detection/yolov5/ixrt/requirements.txt b/models/cv/object_detection/yolov5/ixrt/requirements.txt index f2ec37c1129a168dac9920da4cebdfe78169841a..b0f4374b2b778c81875da50d088fecedd01689c9 100644 --- a/models/cv/object_detection/yolov5/ixrt/requirements.txt +++ b/models/cv/object_detection/yolov5/ixrt/requirements.txt @@ -1,7 +1,7 @@ -tqdm -onnx -onnxsim -ultralytics -pycocotools -opencv-python==4.6.0.66 +tqdm +onnx +onnxsim +ultralytics +pycocotools +opencv-python==4.6.0.66 pycuda \ No newline at end of file diff --git a/models/cv/object_detection/yolov5s/ixrt/requirements.txt b/models/cv/object_detection/yolov5s/ixrt/requirements.txt index a6188db8f77c90851c51ba2066e00ada54bdea98..ffb8ce179fef26f79070045778708b03b8111fce 100644 --- a/models/cv/object_detection/yolov5s/ixrt/requirements.txt +++ b/models/cv/object_detection/yolov5s/ixrt/requirements.txt @@ -1,6 +1,6 @@ -tqdm -onnx -onnxsim -ultralytics -pycocotools +tqdm +onnx +onnxsim +ultralytics +pycocotools pycuda \ No newline at end of file diff --git a/models/cv/object_detection/yolov6/igie/requirements.txt b/models/cv/object_detection/yolov6/igie/requirements.txt index 83d1bed634d31ba3adccbeefd7e3b907d4f4d13f..238c13c62c610c88e9a33ab455ecab236cf72832 100644 --- a/models/cv/object_detection/yolov6/igie/requirements.txt +++ b/models/cv/object_detection/yolov6/igie/requirements.txt @@ -1,4 +1,4 @@ -tqdm -onnx -onnxsim -pycocotools +tqdm +onnx +onnxsim +pycocotools diff --git a/models/cv/object_detection/yolov6/ixrt/requirements.txt b/models/cv/object_detection/yolov6/ixrt/requirements.txt index 2c833a42f5277d91744ea3412a828f6fc55acb01..dc83ddafae43070cd1a5eba3eacfd36fd5b9e8fa 100644 --- a/models/cv/object_detection/yolov6/ixrt/requirements.txt +++ b/models/cv/object_detection/yolov6/ixrt/requirements.txt @@ -1,6 +1,6 @@ -tqdm -onnx -onnxsim -pycocotools -pycuda +tqdm +onnx +onnxsim +pycocotools +pycuda numpy==1.24.0 \ No newline at end of file diff --git a/models/cv/object_detection/yolov7/igie/requirements.txt b/models/cv/object_detection/yolov7/igie/requirements.txt index d808c487baf6ec1db8d4edcd4d6dc1b62905b38a..ba3dfe4661bc285976ec15cc894351debd93aab8 100644 --- a/models/cv/object_detection/yolov7/igie/requirements.txt +++ b/models/cv/object_detection/yolov7/igie/requirements.txt @@ -1,5 +1,5 @@ -tqdm -onnx -onnxsim -pycocotools -seaborn +tqdm +onnx +onnxsim +pycocotools +seaborn diff --git a/models/cv/object_detection/yolov7/ixrt/requirements.txt b/models/cv/object_detection/yolov7/ixrt/requirements.txt index 
f2ec37c1129a168dac9920da4cebdfe78169841a..b0f4374b2b778c81875da50d088fecedd01689c9 100644 --- a/models/cv/object_detection/yolov7/ixrt/requirements.txt +++ b/models/cv/object_detection/yolov7/ixrt/requirements.txt @@ -1,7 +1,7 @@ -tqdm -onnx -onnxsim -ultralytics -pycocotools -opencv-python==4.6.0.66 +tqdm +onnx +onnxsim +ultralytics +pycocotools +opencv-python==4.6.0.66 pycuda \ No newline at end of file diff --git a/models/cv/object_detection/yolov8/igie/requirements.txt b/models/cv/object_detection/yolov8/igie/requirements.txt index fb1b2218f3198ddb5ee4a5397309a6af80924dc1..029e0cd9cc2d673e7e3d8959b41b17f78ac9501a 100644 --- a/models/cv/object_detection/yolov8/igie/requirements.txt +++ b/models/cv/object_detection/yolov8/igie/requirements.txt @@ -1,5 +1,5 @@ -tqdm -onnx -pycocotools -# FAILed in 8.2.51 -ultralytics==8.1.34 +tqdm +onnx +pycocotools +# FAILed in 8.2.51 +ultralytics==8.1.34 diff --git a/models/cv/object_detection/yolov9/igie/requirements.txt b/models/cv/object_detection/yolov9/igie/requirements.txt index 2918f4590c83cc745c902373ee939d5c5f16e9eb..1d97d8213c3e9112c60eab18ce7506cd6867790e 100644 --- a/models/cv/object_detection/yolov9/igie/requirements.txt +++ b/models/cv/object_detection/yolov9/igie/requirements.txt @@ -1,4 +1,4 @@ -onnx -tqdm -onnxsim -ultralytics==8.2.51 +onnx +tqdm +onnxsim +ultralytics==8.2.51 diff --git a/models/cv/object_detection/yolox/igie/requirements.txt b/models/cv/object_detection/yolox/igie/requirements.txt index 83d1bed634d31ba3adccbeefd7e3b907d4f4d13f..238c13c62c610c88e9a33ab455ecab236cf72832 100644 --- a/models/cv/object_detection/yolox/igie/requirements.txt +++ b/models/cv/object_detection/yolox/igie/requirements.txt @@ -1,4 +1,4 @@ -tqdm -onnx -onnxsim -pycocotools +tqdm +onnx +onnxsim +pycocotools diff --git a/models/cv/pose_estimation/hrnetpose/igie/requirements.txt b/models/cv/pose_estimation/hrnetpose/igie/requirements.txt index d4c10d62aa649e117305f47a019157fcd6379f9e..f08199360fb320885a92a3dd74aa786fc415b576 100644 --- a/models/cv/pose_estimation/hrnetpose/igie/requirements.txt +++ b/models/cv/pose_estimation/hrnetpose/igie/requirements.txt @@ -1,7 +1,7 @@ -onnx -tqdm -onnxsim -mmdet==3.3.0 -mmpose==1.3.1 -mmdeploy==1.3.1 -mmengine==0.10.4 +onnx +tqdm +onnxsim +mmdet==3.3.0 +mmpose==1.3.1 +mmdeploy==1.3.1 +mmengine==0.10.4 diff --git a/models/cv/pose_estimation/rtmpose/igie/requirements.txt b/models/cv/pose_estimation/rtmpose/igie/requirements.txt index d4c10d62aa649e117305f47a019157fcd6379f9e..f08199360fb320885a92a3dd74aa786fc415b576 100644 --- a/models/cv/pose_estimation/rtmpose/igie/requirements.txt +++ b/models/cv/pose_estimation/rtmpose/igie/requirements.txt @@ -1,7 +1,7 @@ -onnx -tqdm -onnxsim -mmdet==3.3.0 -mmpose==1.3.1 -mmdeploy==1.3.1 -mmengine==0.10.4 +onnx +tqdm +onnxsim +mmdet==3.3.0 +mmpose==1.3.1 +mmdeploy==1.3.1 +mmengine==0.10.4 diff --git a/models/nlp/plm/bert_base_ner/igie/Int8QAT/bert4torch/__init__.py b/models/nlp/plm/bert_base_ner/igie/Int8QAT/bert4torch/__init__.py index a34d4ecd3bb2428842260045ca2706b0638a8219..b7e9b861c2455612203c316987db0f99bc872df5 100644 --- a/models/nlp/plm/bert_base_ner/igie/Int8QAT/bert4torch/__init__.py +++ b/models/nlp/plm/bert_base_ner/igie/Int8QAT/bert4torch/__init__.py @@ -1,16 +1,16 @@ -# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. 
You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -__version__ = "0.1.9" +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +__version__ = "0.1.9" diff --git a/models/nlp/plm/bert_base_ner/igie/Int8QAT/bert4torch/layers.py b/models/nlp/plm/bert_base_ner/igie/Int8QAT/bert4torch/layers.py index f7449310ca09a0ad0b4ce10de705246df64e052c..443d790979353be624b243d913c5d9d35ead5fc4 100644 --- a/models/nlp/plm/bert_base_ner/igie/Int8QAT/bert4torch/layers.py +++ b/models/nlp/plm/bert_base_ner/igie/Int8QAT/bert4torch/layers.py @@ -1,1667 +1,1667 @@ -# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. 
- -import math -import random -import warnings -from typing import List, Optional - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from bert4torch.activations import get_activation -from bert4torch.snippets import get_sinusoid_encoding_table, take_along_dim -from torch.functional import Tensor - - -class LayerNorm(nn.Module): - def __init__( - self, - hidden_size, - eps=1e-12, - conditional_size=False, - weight=True, - bias=True, - norm_mode="normal", - **kwargs, - ): - super(LayerNorm, self).__init__() - - if weight: - self.weight = nn.Parameter(torch.ones(hidden_size)) - if bias: - self.bias = nn.Parameter(torch.zeros(hidden_size)) - self.norm_mode = norm_mode - - self.eps = eps - self.conditional_size = conditional_size - if conditional_size: - self.dense1 = nn.Linear(conditional_size, hidden_size, bias=False) - self.dense1.weight.data.uniform_(0, 0) - self.dense2 = nn.Linear(conditional_size, hidden_size, bias=False) - self.dense2.weight.data.uniform_(0, 0) - - def forward(self, x): - inputs = x[0] - - if self.norm_mode == "rmsnorm": - variance = inputs.to(torch.float32).pow(2).mean(-1, keepdim=True) - o = inputs * torch.rsqrt(variance + self.eps) - else: - u = inputs.mean(-1, keepdim=True) - s = (inputs - u).pow(2).mean(-1, keepdim=True) - o = (inputs - u) / torch.sqrt(s + self.eps) - - if not hasattr(self, "weight"): - self.weight = 1 - if not hasattr(self, "bias"): - self.bias = 0 - - if self.conditional_size: - cond = x[1] - for _ in range(len(inputs.shape) - len(cond.shape)): - cond = cond.unsqueeze(dim=1) - return (self.weight + self.dense1(cond)) * o + ( - self.bias + self.dense2(cond) - ) - else: - return self.weight * o + self.bias - - -class MultiHeadAttentionLayer(nn.Module): - def __init__( - self, - hidden_size, - num_attention_heads, - attention_probs_dropout_prob, - attention_scale=True, - return_attention_scores=False, - bias=True, - **kwargs, - ): - super(MultiHeadAttentionLayer, self).__init__() - - assert hidden_size % num_attention_heads == 0 - - self.hidden_size = hidden_size - self.num_attention_heads = num_attention_heads - self.attention_head_size = int(hidden_size / num_attention_heads) - self.attention_scale = attention_scale - self.return_attention_scores = return_attention_scores - - self.bias = bias - self.q = nn.Linear(hidden_size, hidden_size, bias=bias) - self.k = nn.Linear(hidden_size, hidden_size, bias=bias) - self.v = nn.Linear(hidden_size, hidden_size, bias=bias) - self.o = nn.Linear(hidden_size, hidden_size, bias=bias) - self.dropout = nn.Dropout(attention_probs_dropout_prob) - - self.a_bias, self.p_bias = kwargs.get("a_bias"), kwargs.get("p_bias") - - if self.p_bias == "typical_relative": # nezha - self.relative_positions_encoding = RelativePositionsEncoding( - qlen=kwargs.get("max_position"), - klen=kwargs.get("max_position"), - embedding_size=self.attention_head_size, - max_relative_position=kwargs.get("max_relative_position"), - ) - elif self.p_bias == "rotary": # roformer - self.relative_positions_encoding = RoPEPositionEncoding( - max_position=kwargs.get("max_position"), - embedding_size=self.attention_head_size, - ) - elif self.p_bias == "t5_relative": # t5 - self.relative_positions = RelativePositionsEncodingT5( - qlen=kwargs.get("max_position"), - klen=kwargs.get("max_position"), - relative_attention_num_buckets=kwargs.get( - "relative_attention_num_buckets" - ), - is_decoder=kwargs.get("is_decoder"), - ) - self.relative_positions_encoding = nn.Embedding( - 
kwargs.get("relative_attention_num_buckets"), self.num_attention_heads - ) - - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + ( - self.num_attention_heads, - self.attention_head_size, - ) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward( - self, - hidden_states, - attention_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - ): - mixed_query_layer = self.q(hidden_states) - if encoder_hidden_states is not None: - mixed_key_layer = self.k(encoder_hidden_states) - mixed_value_layer = self.v(encoder_hidden_states) - attention_mask = encoder_attention_mask - else: - mixed_key_layer = self.k(hidden_states) - mixed_value_layer = self.v(hidden_states) - - query_layer = self.transpose_for_scores(mixed_query_layer) - key_layer = self.transpose_for_scores(mixed_key_layer) - value_layer = self.transpose_for_scores(mixed_value_layer) - - if self.p_bias == "rotary": - query_layer = self.relative_positions_encoding(query_layer) - key_layer = self.relative_positions_encoding(key_layer) - - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - - if (self.p_bias == "typical_relative") and hasattr( - self, "relative_positions_encoding" - ): - relations_keys = self.relative_positions_encoding( - attention_scores.shape[-1], attention_scores.shape[-1] - ) - key_position_scores_r_t = torch.einsum( - "bnih,ijh->bnij", query_layer, relations_keys - ) - attention_scores = attention_scores + key_position_scores_r_t - elif (self.p_bias == "t5_relative") and hasattr( - self, "relative_positions_encoding" - ): - relations_keys = self.relative_positions( - attention_scores.shape[-1], attention_scores.shape[-1] - ) - key_position_scores_r_t = ( - self.relative_positions_encoding(relations_keys) - .permute([2, 0, 1]) - .unsqueeze(0) - ) - attention_scores = attention_scores + key_position_scores_r_t - - if self.attention_scale: - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - if attention_mask is not None: - attention_mask = ( - 1.0 - attention_mask - ) * -10000.0 - attention_scores = attention_scores + attention_mask - - attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = self.dropout(attention_probs) - context_layer = torch.matmul( - attention_probs, value_layer - ) # [batch_size, num_attention_heads, query_len, attention_head_size] - - if (self.p_bias == "typical_relative") and hasattr( - self, "relative_positions_encoding" - ): - relations_values = self.relative_positions_encoding( - attention_scores.shape[-1], attention_scores.shape[-1] - ) - value_position_scores_r_t = torch.einsum( - "bnij,ijh->bnih", attention_probs, relations_values - ) - context_layer = context_layer + value_position_scores_r_t - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,) - context_layer = context_layer.view(*new_context_layer_shape) - - if self.return_attention_scores: - return self.o(context_layer), attention_scores - else: - return self.o(context_layer) - - -class PositionWiseFeedForward(nn.Module): - def __init__( - self, - hidden_size, - intermediate_size, - dropout_rate=0.5, - hidden_act="gelu", - is_dropout=False, - bias=True, - **kwargs, - ): - super(PositionWiseFeedForward, self).__init__() - - self.is_dropout = is_dropout - self.intermediate_act_fn = get_activation(hidden_act) - self.intermediateDense = nn.Linear(hidden_size, intermediate_size, bias=bias) - self.outputDense = 
nn.Linear(intermediate_size, hidden_size, bias=bias) - if self.is_dropout: - self.dropout = nn.Dropout(dropout_rate) - - def forward(self, x): - # x shape: (batch size, seq len, hidden_size) - if self.is_dropout: - x = self.dropout(self.intermediate_act_fn(self.intermediateDense(x))) - else: - x = self.intermediate_act_fn(self.intermediateDense(x)) - - # x shape: (batch size, seq len, intermediate_size) - x = self.outputDense(x) - - # x shape: (batch size, seq len, hidden_size) - return x - - -class GatedAttentionUnit(nn.Module): - def __init__( - self, - hidden_size, - attention_key_size, - intermediate_size, - attention_probs_dropout_prob, - hidden_act, - is_dropout=False, - attention_scale=True, - bias=True, - normalization="softmax_plus", - **kwargs, - ): - super().__init__() - self.intermediate_size = intermediate_size - self.attention_head_size = attention_key_size - self.attention_scale = attention_scale - self.is_dropout = is_dropout - self.normalization = normalization - self.hidden_fn = get_activation(hidden_act) - self.dropout = nn.Dropout(attention_probs_dropout_prob) - self.i_dense = nn.Linear( - hidden_size, self.intermediate_size * 2 + attention_key_size, bias=bias - ) - self.offsetscale = self.OffsetScale(attention_key_size, heads=2, bias=bias) - self.o_dense = nn.Linear(self.intermediate_size, hidden_size, bias=bias) - - self.a_bias, self.p_bias = kwargs.get("a_bias"), kwargs.get("p_bias") - if self.p_bias == "rotary": # RoPE - self.relative_positions_encoding = RoPEPositionEncoding( - max_position=kwargs.get("max_position"), - embedding_size=self.attention_head_size, - ) - - def forward(self, hidden_states, attention_mask): - hidden_states = self.hidden_fn(self.i_dense(hidden_states)) - u, v, qk = hidden_states.split( - [self.intermediate_size, self.intermediate_size, self.attention_head_size], - dim=-1, - ) - q, k = self.offsetscale(qk) - - if self.p_bias == "rotary": - q = self.relative_positions_encoding(q) - k = self.relative_positions_encoding(k) - - # Attention - attention_scores = torch.einsum( - "b i d, b j d -> b i j", q, k - ) # [btz, seq_len, seq_len] - if self.attention_scale: - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - - if attention_mask is not None: - attention_mask = (1.0 - attention_mask) * -1e12 - attention_scores = attention_scores + attention_mask.squeeze(1) - - # 归一化 - attention_scores = self.attention_normalize( - attention_scores, -1, self.normalization - ) - - if self.is_dropout: - attention_scores = self.dropout(attention_scores) - - # 计算输出 - out = self.o_dense( - u * torch.einsum("b i j, b j d -> b i d", attention_scores, v) - ) - return out - - def attention_normalize(self, a, dim=-1, method="softmax"): - if method == "softmax": - return F.softmax(a, dim=dim) - else: - mask = (a > -1e11).float() - l = torch.maximum( - torch.sum(mask, dim=dim, keepdims=True), torch.tensor(1).to(mask) - ) - if method == "squared_relu": - return F.relu(a) ** 2 / l - elif method == "softmax_plus": - return F.softmax( - a * torch.log(l) / torch.log(torch.tensor(512)).to(mask), dim=dim - ) - return a - - class OffsetScale(nn.Module): - def __init__(self, head_size, heads=1, bias=True): - super().__init__() - self.gamma = nn.Parameter(torch.ones(heads, head_size)) - self.bias = bias - if bias: - self.beta = nn.Parameter(torch.zeros(heads, head_size)) - nn.init.normal_(self.gamma, std=0.02) - - def forward(self, x): - out = torch.einsum("... d, h d -> ... 
h d", x, self.gamma) - if self.bias: - out = out + self.beta - return out.unbind(dim=-2) - - -class BertEmbeddings(nn.Module): - def __init__( - self, - vocab_size, - embedding_size, - hidden_size, - max_position, - segment_vocab_size, - shared_segment_embeddings, - drop_rate, - conditional_size=False, - **kwargs, - ): - super(BertEmbeddings, self).__init__() - self.shared_segment_embeddings = shared_segment_embeddings - self.word_embeddings = nn.Embedding(vocab_size, embedding_size, padding_idx=0) - - if kwargs.get("p_bias") == "sinusoid": - self.position_embeddings = SinusoidalPositionEncoding( - max_position, embedding_size - ) - elif kwargs.get("p_bias") in { - "rotary", - "typical_relative", - "t5_relative", - "other_relative", - }: - pass - elif max_position > 0: - self.position_embeddings = nn.Embedding(max_position, embedding_size) - - if (segment_vocab_size > 0) and (not shared_segment_embeddings): - self.segment_embeddings = nn.Embedding(segment_vocab_size, embedding_size) - - # emb_scale - self.emb_scale = kwargs.get("emb_scale", 1) - - # LayerNorm - self.layerNorm = LayerNorm( - embedding_size, eps=1e-12, conditional_size=conditional_size, **kwargs - ) - self.dropout = nn.Dropout(drop_rate) - - if embedding_size != hidden_size: - self.embedding_hidden_mapping_in = nn.Linear(embedding_size, hidden_size) - - def forward( - self, token_ids, segment_ids=None, conditional_emb=None, additional_embs=None - ): - if (not token_ids.requires_grad) and ( - token_ids.dtype in {torch.long, torch.int} - ): - words_embeddings = self.word_embeddings(token_ids) - else: - words_embeddings = token_ids - - if hasattr(self, "segment_embeddings"): - segment_ids = ( - torch.zeros_like(token_ids) if segment_ids is None else segment_ids - ) - segment_embeddings = self.segment_embeddings(segment_ids) - embeddings = words_embeddings + segment_embeddings - elif self.shared_segment_embeddings: - segment_ids = ( - torch.zeros_like(token_ids) if segment_ids is None else segment_ids - ) - segment_embeddings = self.word_embeddings(segment_ids) - embeddings = words_embeddings + segment_embeddings - else: - embeddings = words_embeddings - - if additional_embs is not None: - for emb in additional_embs: - embeddings += emb - - if hasattr(self, "position_embeddings"): - seq_length = token_ids.size(1) - position_ids = torch.arange( - seq_length, dtype=torch.long, device=token_ids.device - ) - position_ids = position_ids.unsqueeze(0).repeat(token_ids.shape[0], 1) - position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings - - if self.emb_scale != 1: - embeddings = embeddings * self.emb_scale - - if hasattr(self, "layerNorm"): - embeddings = self.layerNorm((embeddings, conditional_emb)) - embeddings = self.dropout(embeddings) - - if hasattr(self, "embedding_hidden_mapping_in"): - embeddings = self.embedding_hidden_mapping_in(embeddings) - return embeddings - - -class BertLayer(nn.Module): - def __init__( - self, - hidden_size, - num_attention_heads, - dropout_rate, - attention_probs_dropout_prob, - intermediate_size, - hidden_act, - is_dropout=False, - conditional_size=False, - **kwargs, - ): - super(BertLayer, self).__init__() - self.multiHeadAttention = MultiHeadAttentionLayer( - hidden_size, num_attention_heads, attention_probs_dropout_prob, **kwargs - ) - self.dropout1 = nn.Dropout(dropout_rate) - self.layerNorm1 = LayerNorm( - hidden_size, eps=1e-12, conditional_size=conditional_size, **kwargs - ) - self.feedForward = PositionWiseFeedForward( - hidden_size, - 
intermediate_size, - dropout_rate, - hidden_act, - is_dropout=is_dropout, - **kwargs, - ) - self.dropout2 = nn.Dropout(dropout_rate) - self.layerNorm2 = LayerNorm( - hidden_size, eps=1e-12, conditional_size=conditional_size, **kwargs - ) - self.is_decoder = kwargs.get("is_decoder") - if self.is_decoder: - self.crossAttention = MultiHeadAttentionLayer( - hidden_size, num_attention_heads, attention_probs_dropout_prob, **kwargs - ) - self.dropout3 = nn.Dropout(dropout_rate) - self.layerNorm3 = LayerNorm( - hidden_size, eps=1e-12, conditional_size=conditional_size, **kwargs - ) - - def forward( - self, - hidden_states, - attention_mask, - conditional_emb=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - ): - self_attn_output = self.multiHeadAttention( - hidden_states, attention_mask - ) - hidden_states = hidden_states + self.dropout1(self_attn_output) - hidden_states = self.layerNorm1((hidden_states, conditional_emb)) - - # cross attention - if self.is_decoder and encoder_hidden_states is not None: - cross_attn_output = self.crossAttention( - hidden_states, None, encoder_hidden_states, encoder_attention_mask - ) - hidden_states = hidden_states + self.dropout3(cross_attn_output) - hidden_states = self.layerNorm3((hidden_states, conditional_emb)) - - self_attn_output2 = self.feedForward(hidden_states) - hidden_states = hidden_states + self.dropout2(self_attn_output2) - hidden_states = self.layerNorm2((hidden_states, conditional_emb)) - return hidden_states - - -class T5Layer(BertLayer): - def __init__(self, *args, version="t5.1.0", **kwargs): - super().__init__(*args, **kwargs) - - if version.endswith("t5.1.1"): - kwargs["dropout_rate"] = args[2] - kwargs["hidden_act"] = args[5] - self.feedForward = self.T5PositionWiseFeedForward( - hidden_size=args[0], intermediate_size=args[4], **kwargs - ) - - if self.is_decoder and hasattr( - self.crossAttention, "relative_positions_encoding" - ): - del self.crossAttention.relative_positions_encoding - del self.crossAttention.relative_positions - - def forward( - self, - hidden_states, - attention_mask, - conditional_emb=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - ): - x = self.layerNorm1((hidden_states, conditional_emb)) - self_attn_output = self.multiHeadAttention(x, attention_mask) - hidden_states = hidden_states + self.dropout1(self_attn_output) - - # cross attention - if self.is_decoder and encoder_hidden_states is not None: - x = self.layerNorm3((hidden_states, conditional_emb)) - cross_attn_output = self.crossAttention( - x, None, encoder_hidden_states, encoder_attention_mask - ) - hidden_states = hidden_states + self.dropout3(cross_attn_output) - - x = self.layerNorm2((hidden_states, conditional_emb)) - ffn_output = self.feedForward(x) - hidden_states = hidden_states + self.dropout2(ffn_output) - return hidden_states - - class T5PositionWiseFeedForward(PositionWiseFeedForward): - def __init__(self, hidden_size, intermediate_size, **kwargs): - super().__init__(hidden_size, intermediate_size, **kwargs) - self.intermediateDense = nn.Linear( - hidden_size, intermediate_size, bias=False - ) - self.intermediateDense1 = nn.Linear( - hidden_size, intermediate_size, bias=False - ) - self.outputDense = nn.Linear(intermediate_size, hidden_size, bias=False) - - def forward(self, x): - # x shape: (batch size, seq len, hidden_size) - x_gelu = self.intermediate_act_fn(self.intermediateDense(x)) - x_linear = self.intermediateDense1(x) - x = x_gelu * x_linear - if self.is_dropout: - x = self.dropout(x) - - # x shape: (batch 
size, seq len, intermediate_size) - x = self.outputDense(x) - - # x shape: (batch size, seq len, hidden_size) - return x - - -class XlnetLayer(BertLayer): - def __init__( - self, - hidden_size, - num_attention_heads, - dropout_rate, - attention_probs_dropout_prob, - intermediate_size, - hidden_act, - **kwargs, - ): - super().__init__( - hidden_size, - num_attention_heads, - dropout_rate, - attention_probs_dropout_prob, - intermediate_size, - hidden_act, - **kwargs, - ) - self.pre_lnorm = kwargs.get("pre_lnorm") - self.multiHeadAttention = self.RelPartialLearnableMultiHeadAttn( - hidden_size, - num_attention_heads, - attention_probs_dropout_prob, - bias=False, - **kwargs, - ) - - def forward( - self, - hidden_states, - segment_ids, - pos_emb, - attention_mask, - mems_i, - conditional_emb=None, - ): - hidden_states_cat = ( - torch.cat([mems_i, hidden_states], 1) - if mems_i is not None - else hidden_states - ) - - # Attn - if self.pre_lnorm: - hidden_states_cat = self.layerNorm1((hidden_states_cat, conditional_emb)) - self_attn_output = self.multiHeadAttention( - hidden_states, hidden_states_cat, pos_emb, attention_mask, segment_ids - ) - hidden_states = hidden_states + self.dropout1(self_attn_output) - if not self.pre_lnorm: # post_lnorm - hidden_states = self.layerNorm1((hidden_states, conditional_emb)) - - # FFN - x = ( - self.layerNorm2((hidden_states, conditional_emb)) - if self.pre_lnorm - else hidden_states - ) - self_attn_output2 = self.feedForward(x) - hidden_states = hidden_states + self.dropout2(self_attn_output2) - if not self.pre_lnorm: # post_lnorm - hidden_states = self.layerNorm2((hidden_states, conditional_emb)) - return hidden_states - - class RelPartialLearnableMultiHeadAttn(MultiHeadAttentionLayer): - - def __init__( - self, *args, r_w_bias=None, r_r_bias=None, r_s_bias=None, **kwargs - ): - super().__init__(*args, **kwargs) - segment_vocab_size = kwargs.get("segment_vocab_size") - if r_r_bias is None or r_w_bias is None: # Biases are not shared - self.r_r_bias = nn.Parameter( - torch.FloatTensor( - self.num_attention_heads, self.attention_head_size - ) - ) - self.r_w_bias = nn.Parameter( - torch.FloatTensor( - self.num_attention_heads, self.attention_head_size - ) - ) - if segment_vocab_size > 0: - self.r_s_bias = nn.Parameter( - torch.FloatTensor( - self.num_attention_heads, self.attention_head_size - ) - ) - else: - self.r_r_bias = r_r_bias - self.r_w_bias = r_w_bias - self.r_s_bias = r_s_bias - if segment_vocab_size > 0: - # self.seg_embed = nn.Embedding(segment_vocab_size, self.hidden_size) - self.seg_embed = nn.Parameter( - torch.FloatTensor( - segment_vocab_size, - self.num_attention_heads, - self.attention_head_size, - ) - ) - - self.r = nn.Linear(self.hidden_size, self.hidden_size, bias=self.bias) - self.rel_shift_opt = kwargs.get("rel_shift_opt") - - @staticmethod - def rel_shift(x, zero_triu=False): - q_len, k_len = x.size(2), x.size(-1) - zero_pad = torch.zeros( - (*x.size()[:2], q_len, 1), device=x.device, dtype=x.dtype - ) - x_padded = torch.cat([zero_pad, x], dim=-1) - x_padded = x_padded.view(*x.size()[:2], k_len + 1, q_len) - x = x_padded[:, :, 1:, :].view_as(x) - if zero_triu: - ones = torch.ones((q_len, k_len), device=x.device) - x = x * torch.tril(ones, k_len - q_len)[None, None, :, :] - return x - - @staticmethod - def rel_shift_bnij(x, klen=-1): - x_size = x.shape - x = x.reshape(x_size[0], x_size[1], x_size[3], x_size[2]) - x = x[:, :, 1:, :] - x = x.reshape(x_size[0], x_size[1], x_size[2], x_size[3] - 1) - x = torch.index_select( - x, 3, 
torch.arange(klen, device=x.device, dtype=torch.long) - ) - return x - - def forward(self, w, cat, r, attention_mask=None, seg_mat=None): - qlen, rlen, bsz = w.size(1), r.size(0), w.size(0) - - mixed_query_layer = self.q(cat)[:, -qlen:, :] - mixed_key_layer = self.k(cat) - mixed_value_layer = self.v(cat) - - w_head_q = self.transpose_for_scores( - mixed_query_layer - ) # [btz, n_head, q_len, d_head] - w_head_k = self.transpose_for_scores( - mixed_key_layer - ) # [btz, n_head, k_len, d_head] - w_head_v = self.transpose_for_scores( - mixed_value_layer - ) # [btz, n_head, k_len, d_head] - - r_head_k = self.r(r) # [hdsz, nhead*headsize] = [r_len, 1, nhead*headsize] - r_head_k = r_head_k.view( - rlen, self.num_attention_heads, self.attention_head_size - ) # rlen x n_head x d_head - - #### compute attention score - rw_head_q = w_head_q + self.r_w_bias.unsqueeze( - 1 - ) # [btz, n_head, q_len, d_head] - AC = torch.einsum( - "bnid,bnjd->bnij", (rw_head_q, w_head_k) - ) # [btz, n_head, q_len, k_len] - - rr_head_q = w_head_q + self.r_r_bias.unsqueeze( - 1 - ) # [btz, n_head, q_len, d_head] - BD = torch.einsum( - "bnid,jnd->bnij", (rr_head_q, r_head_k) - ) # [btz, n_head, q_len, k_len] - BD = ( - self.rel_shift_bnij(BD, klen=AC.shape[3]) - if self.rel_shift_opt == "xlnet" - else self.rel_shift(BD) - ) - - if hasattr(self, "seg_embed") and (self.r_r_bias is not None): - seg_mat = F.one_hot(seg_mat, 2).float() - EF = torch.einsum( - "bnid,snd->ibns", - w_head_q + self.r_s_bias.unsqueeze(1), - self.seg_embed, - ) - EF = torch.einsum("bijs,ibns->bnij", seg_mat, EF) - else: - EF = 0 - - # # [btz, n_head, q_len, k_len] - attention_scores = AC + BD + EF - if self.attention_scale: - attention_scores = attention_scores / math.sqrt( - self.attention_head_size - ) - - if attention_mask is not None and attention_mask.any().item(): - attention_mask = 1.0 - attention_mask - attention_scores = ( - attention_scores.float() - .masked_fill(attention_mask.bool(), -1e30) - .type_as(attention_mask) - ) - - # [btz, n_head, q_len, k_len] - attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = self.dropout(attention_probs) - context_layer = torch.matmul( - attention_probs, w_head_v - ) # [batch_size, num_attention_heads, query_len, attention_head_size] - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,) - context_layer = context_layer.view(*new_context_layer_shape) - - if self.return_attention_scores: - return self.o(context_layer), attention_scores - else: - return self.o(context_layer) - - -class AdaptiveEmbedding(nn.Module): - def __init__( - self, - vocab_size, - embedding_size, - hidden_size, - cutoffs, - div_val=1, - sample_softmax=False, - **kwargs, - ): - super().__init__() - self.vocab_size = vocab_size - self.embedding_size = embedding_size - self.cutoffs = cutoffs + [vocab_size] - self.div_val = div_val - self.hidden_size = hidden_size - self.emb_scale = hidden_size**0.5 - self.cutoff_ends = [0] + self.cutoffs - - self.emb_layers = nn.ModuleList() - self.emb_projs = nn.ParameterList() - if div_val == 1: - self.emb_layers.append( - nn.Embedding(vocab_size, embedding_size, sparse=sample_softmax > 0) - ) - if hidden_size != embedding_size: - self.emb_projs.append( - nn.Parameter(torch.FloatTensor(hidden_size, embedding_size)) - ) - else: - for i in range(len(self.cutoffs)): - l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] - d_emb_i = embedding_size // (div_val**i) - 
self.emb_layers.append(nn.Embedding(r_idx - l_idx, d_emb_i)) - self.emb_projs.append( - nn.Parameter(torch.FloatTensor(hidden_size, d_emb_i)) - ) - - def forward(self, token_ids): - if self.div_val == 1: - embed = self.emb_layers[0](token_ids) # [btz, seq_len, embedding_size] - if self.hidden_size != self.embedding_size: - embed = nn.functional.linear(embed, self.emb_projs[0]) - else: - param = next(self.parameters()) - inp_flat = token_ids.view(-1) - emb_flat = torch.zeros( - [inp_flat.size(0), self.hidden_size], - dtype=param.dtype, - device=param.device, - ) - for i in range(len(self.cutoffs)): - l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] - - mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx) - indices_i = mask_i.nonzero().squeeze() - - if indices_i.numel() == 0: - continue - - inp_i = inp_flat.index_select(0, indices_i) - l_idx - emb_i = self.emb_layers[i](inp_i) - emb_i = nn.functional.linear(emb_i, self.emb_projs[i]) - - emb_flat.index_copy_(0, indices_i, emb_i) - - embed_shape = token_ids.size() + (self.hidden_size,) - embed = emb_flat.view(embed_shape) - - embed.mul_(self.emb_scale) - - return embed - - -class Identity(nn.Module): - def __init__(self, *args, **kwargs): - super(Identity, self).__init__() - - def forward(self, *args): - return args[0] - - -class XlnetPositionsEncoding(nn.Module): - def __init__(self, embedding_size): - super().__init__() - self.demb = embedding_size - inv_freq = 1 / ( - 10000 ** (torch.arange(0.0, embedding_size, 2.0) / embedding_size) - ) - self.register_buffer("inv_freq", inv_freq) - - def forward(self, pos_seq): - sinusoid_inp = torch.ger(pos_seq, self.inv_freq) - pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) - return pos_emb - - -class RelativePositionsEncoding(nn.Module): - - def __init__(self, qlen, klen, embedding_size, max_relative_position=127): - super(RelativePositionsEncoding, self).__init__() - # 生成相对位置矩阵 - vocab_size = max_relative_position * 2 + 1 - distance_mat = ( - torch.arange(klen)[None, :] - torch.arange(qlen)[:, None] - ) # 列数-行数, [query_len, key_len] - distance_mat_clipped = torch.clamp( - distance_mat, -max_relative_position, max_relative_position - ) - final_mat = distance_mat_clipped + max_relative_position - - embeddings_table = get_sinusoid_encoding_table(vocab_size, embedding_size) - - position_embeddings = nn.Embedding.from_pretrained( - embeddings_table, freeze=True - )(final_mat) - self.register_buffer("position_embeddings", position_embeddings) - - def forward(self, qlen, klen): - return self.position_embeddings[:qlen, :klen, :] - - -class RelativePositionsEncodingT5(nn.Module): - - def __init__(self, qlen, klen, relative_attention_num_buckets, is_decoder=False): - super(RelativePositionsEncodingT5, self).__init__() - context_position = torch.arange(qlen, dtype=torch.long)[:, None] - memory_position = torch.arange(klen, dtype=torch.long)[None, :] - relative_position = memory_position - context_position # shape (qlen, klen) - relative_position = self._relative_position_bucket( - relative_position, # shape (qlen, klen) - bidirectional=not is_decoder, - num_buckets=relative_attention_num_buckets, - ) - self.register_buffer("relative_position", relative_position) - - def forward(self, qlen, klen): - return self.relative_position[:qlen, :klen] - - @staticmethod - def _relative_position_bucket( - relative_position, bidirectional=True, num_buckets=32, max_distance=128 - ): - """直接来源于transformer""" - ret = 0 - n = -relative_position - if bidirectional: - num_buckets //= 2 - ret += (n < 
0).to( - torch.long - ) * num_buckets # mtf.to_int32(mtf.less(n, 0)) * num_buckets - n = torch.abs(n) - else: - n = torch.max(n, torch.zeros_like(n)) - # now n is in the range [0, inf) - - # half of the buckets are for exact increments in positions - max_exact = num_buckets // 2 - is_small = n < max_exact - - # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance - val_if_large = max_exact + ( - torch.log(n.float() / max_exact) - / math.log(max_distance / max_exact) - * (num_buckets - max_exact) - ).to(torch.long) - val_if_large = torch.min( - val_if_large, torch.full_like(val_if_large, num_buckets - 1) - ) - - ret += torch.where(is_small, n, val_if_large) - return ret - - -class SinusoidalPositionEncoding(nn.Module): - """定义Sin-Cos位置Embedding""" - - def __init__(self, max_position, embedding_size): - super(SinusoidalPositionEncoding, self).__init__() - self.position_embeddings = nn.Embedding.from_pretrained( - get_sinusoid_encoding_table(max_position, embedding_size), freeze=True - ) - - def forward(self, position_ids): - return self.position_embeddings(position_ids) - - -class RoPEPositionEncoding(nn.Module): - def __init__(self, max_position, embedding_size): - super(RoPEPositionEncoding, self).__init__() - position_embeddings = get_sinusoid_encoding_table( - max_position, embedding_size - ) # [seq_len, hdsz] - cos_position = position_embeddings[:, 1::2].repeat_interleave(2, dim=-1) - sin_position = position_embeddings[:, ::2].repeat_interleave(2, dim=-1) - self.register_buffer("cos_position", cos_position) - self.register_buffer("sin_position", sin_position) - - def forward(self, qw, seq_dim=-2): - seq_len = qw.shape[seq_dim] - qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], dim=-1).reshape_as(qw) - return qw * self.cos_position[:seq_len] + qw2 * self.sin_position[:seq_len] - - -class CRF(nn.Module): - def __init__( - self, - num_tags: int, - init_transitions: Optional[List[np.ndarray]] = None, - freeze=False, - ) -> None: - if num_tags <= 0: - raise ValueError(f"invalid number of tags: {num_tags}") - super().__init__() - self.num_tags = num_tags - if (init_transitions is None) and (not freeze): - self.start_transitions = nn.Parameter(torch.empty(num_tags)) - self.end_transitions = nn.Parameter(torch.empty(num_tags)) - self.transitions = nn.Parameter(torch.empty(num_tags, num_tags)) - nn.init.uniform_(self.start_transitions, -0.1, 0.1) - nn.init.uniform_(self.end_transitions, -0.1, 0.1) - nn.init.uniform_(self.transitions, -0.1, 0.1) - elif init_transitions is not None: - transitions = torch.tensor(init_transitions[0], dtype=torch.float) - start_transitions = torch.tensor(init_transitions[1], dtype=torch.float) - end_transitions = torch.tensor(init_transitions[2], dtype=torch.float) - - if not freeze: - self.transitions = nn.Parameter(transitions) - self.start_transitions = nn.Parameter(start_transitions) - self.end_transitions = nn.Parameter(end_transitions) - else: - self.register_buffer("transitions", transitions) - self.register_buffer("start_transitions", start_transitions) - self.register_buffer("end_transitions", end_transitions) - - def __repr__(self) -> str: - return f"{self.__class__.__name__}(num_tags={self.num_tags})" - - def forward( - self, - emissions: torch.Tensor, - mask: torch.ByteTensor, - tags: torch.LongTensor, - reduction: str = "mean", - ) -> torch.Tensor: - """Compute the conditional log likelihood of a sequence of tags given emission scores. 
- emissions: [btz, seq_len, num_tags] - mask: [btz, seq_len] - tags: [btz, seq_len] - """ - if reduction not in ("none", "sum", "mean", "token_mean"): - raise ValueError(f"invalid reduction: {reduction}") - if mask.dtype != torch.uint8: - mask = mask.byte() - self._validate(emissions, tags=tags, mask=mask) - - # shape: (batch_size,) - numerator = self._compute_score(emissions, tags, mask) - # shape: (batch_size,) - denominator = self._compute_normalizer(emissions, mask) - # shape: (batch_size,) - llh = denominator - numerator - - if reduction == "none": - return llh - if reduction == "sum": - return llh.sum() - if reduction == "mean": - return llh.mean() - return llh.sum() / mask.float().sum() - - def decode( - self, - emissions: torch.Tensor, - mask: Optional[torch.ByteTensor] = None, - nbest: Optional[int] = None, - pad_tag: Optional[int] = None, - ) -> List[List[List[int]]]: - """Find the most likely tag sequence using Viterbi algorithm.""" - if nbest is None: - nbest = 1 - if mask is None: - mask = torch.ones( - emissions.shape[:2], dtype=torch.uint8, device=emissions.device - ) - if mask.dtype != torch.uint8: - mask = mask.byte() - self._validate(emissions, mask=mask) - - best_path = self._viterbi_decode_nbest(emissions, mask, nbest, pad_tag) - return best_path[0] if nbest == 1 else best_path - - def _validate( - self, - emissions: torch.Tensor, - tags: Optional[torch.LongTensor] = None, - mask: Optional[torch.ByteTensor] = None, - ) -> None: - if emissions.dim() != 3: - raise ValueError( - f"emissions must have dimension of 3, got {emissions.dim()}" - ) - if emissions.size(2) != self.num_tags: - raise ValueError( - f"expected last dimension of emissions is {self.num_tags}, " - f"got {emissions.size(2)}" - ) - if tags is not None: - if emissions.shape[:2] != tags.shape: - raise ValueError( - "the first two dimensions of emissions and tags must match, " - f"got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}" - ) - if mask is not None: - if emissions.shape[:2] != mask.shape: - raise ValueError( - "the first two dimensions of emissions and mask must match, " - f"got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}" - ) - no_empty_seq_bf = mask[:, 0].all() - if not no_empty_seq_bf: - raise ValueError("mask of the first timestep must all be on") - - def _compute_score( - self, emissions: torch.Tensor, tags: torch.LongTensor, mask: torch.ByteTensor - ) -> torch.Tensor: - # emissions: (batch_size, seq_length, num_tags) - # tags: (batch_size, seq_length) - # mask: (batch_size, seq_length) - batch_size, seq_length = tags.shape - mask = mask.float() - - # Start transition score and first emission - # shape: (batch_size,) - score = self.start_transitions[tags[:, 0]] - score += emissions[torch.arange(batch_size), 0, tags[:, 0]] - - for i in range(1, seq_length): - # Transition score to next tag, only added if next timestep is valid (mask == 1) - # shape: (batch_size,) - score += self.transitions[tags[:, i - 1], tags[:, i]] * mask[:, i] - # Emission score for next tag, only added if next timestep is valid (mask == 1) - # shape: (batch_size,) - score += emissions[torch.arange(batch_size), i, tags[:, i]] * mask[:, i] - - # End transition score - # shape: (batch_size,) - seq_ends = mask.long().sum(dim=1) - 1 - # shape: (batch_size,) - last_tags = tags[torch.arange(batch_size), seq_ends] - # shape: (batch_size,) - score += self.end_transitions[last_tags] - - return score - - def _compute_normalizer( - self, emissions: torch.Tensor, mask: torch.ByteTensor - ) -> torch.Tensor: - # emissions: 
(batch_size, seq_length, num_tags) - # mask: (batch_size, seq_length) - seq_length = emissions.size(1) - - # Start transition score and first emission; score has size of - # (batch_size, num_tags) where for each batch, the j-th column stores - # the score that the first timestep has tag j - # shape: (batch_size, num_tags) - score = self.start_transitions + emissions[:, 0] - - for i in range(1, seq_length): - # Broadcast score for every possible next tag - # shape: (batch_size, num_tags, 1) - broadcast_score = score.unsqueeze(2) - - # Broadcast emission score for every possible current tag - # shape: (batch_size, 1, num_tags) - broadcast_emissions = emissions[:, i].unsqueeze(1) - - # Compute the score tensor of size (batch_size, num_tags, num_tags) where - # for each sample, entry at row i and column j stores the sum of scores of all - # possible tag sequences so far that end with transitioning from tag i to tag j - # and emitting - # shape: (batch_size, num_tags, num_tags) - next_score = broadcast_score + self.transitions + broadcast_emissions - - # Sum over all possible current tags, but we're in score space, so a sum - # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of - # all possible tag sequences so far, that end in tag i - # shape: (batch_size, num_tags) - next_score = torch.logsumexp(next_score, dim=1) - - # Set score to the next score if this timestep is valid (mask == 1) - # shape: (batch_size, num_tags) - score = torch.where(mask[:, i].unsqueeze(1).bool(), next_score, score) - - # End transition score - # shape: (batch_size, num_tags) - score += self.end_transitions - - # Sum (log-sum-exp) over all possible tags - # shape: (batch_size,) - return torch.logsumexp(score, dim=1) - - def _viterbi_decode_nbest( - self, - emissions: torch.FloatTensor, - mask: torch.ByteTensor, - nbest: int, - pad_tag: Optional[int] = None, - ) -> List[List[List[int]]]: - # emissions: (batch_size, seq_length, num_tags) - # mask: (batch_size, seq_length) - # return: (nbest, batch_size, seq_length) - if pad_tag is None: - pad_tag = 0 - - device = emissions.device - batch_size, seq_length = mask.shape - - # Start transition and first emission - # shape: (batch_size, num_tags) - score = self.start_transitions + emissions[:, 0] - history_idx = torch.zeros( - (batch_size, seq_length, self.num_tags, nbest), - dtype=torch.long, - device=device, - ) - oor_idx = torch.zeros( - (batch_size, self.num_tags, nbest), dtype=torch.long, device=device - ) - oor_tag = torch.full( - (batch_size, seq_length, nbest), pad_tag, dtype=torch.long, device=device - ) - - # - score is a tensor of size (batch_size, num_tags) where for every batch, - # value at column j stores the score of the best tag sequence so far that ends - # with tag j - # - history_idx saves where the best tags candidate transitioned from; this is used - # when we trace back the best tag sequence - # - oor_idx saves the best tags candidate transitioned from at the positions - # where mask is 0, i.e. 
out of range (oor) - - # Viterbi algorithm recursive case: we compute the score of the best tag sequence - # for every possible next tag - for i in range(1, seq_length): - if i == 1: - broadcast_score = score.unsqueeze(-1) - broadcast_emission = emissions[:, i].unsqueeze(1) - # shape: (batch_size, num_tags, num_tags) - next_score = broadcast_score + self.transitions + broadcast_emission - else: - broadcast_score = score.unsqueeze(-1) - broadcast_emission = emissions[:, i].unsqueeze(1).unsqueeze(2) - # shape: (batch_size, num_tags, nbest, num_tags) - next_score = ( - broadcast_score + self.transitions.unsqueeze(1) + broadcast_emission - ) - - # Find the top `nbest` maximum score over all possible current tag - # shape: (batch_size, nbest, num_tags) - next_score, indices = next_score.view(batch_size, -1, self.num_tags).topk( - nbest, dim=1 - ) - - if i == 1: - score = score.unsqueeze(-1).expand(-1, -1, nbest) - indices = indices * nbest - - # convert to shape: (batch_size, num_tags, nbest) - next_score = next_score.transpose(2, 1) - indices = indices.transpose(2, 1) - - # Set score to the next score if this timestep is valid (mask == 1) - # and save the index that produces the next score - # shape: (batch_size, num_tags, nbest) - score = torch.where( - mask[:, i].unsqueeze(-1).unsqueeze(-1).bool(), next_score, score - ) - indices = torch.where( - mask[:, i].unsqueeze(-1).unsqueeze(-1).bool(), indices, oor_idx - ) - history_idx[:, i - 1] = indices - - # End transition score shape: (batch_size, num_tags, nbest) - end_score = score + self.end_transitions.unsqueeze(-1) - _, end_tag = end_score.view(batch_size, -1).topk(nbest, dim=1) - - # shape: (batch_size,) - seq_ends = mask.long().sum(dim=1) - 1 - - # insert the best tag at each sequence end (last position with mask == 1) - history_idx.scatter_( - 1, - seq_ends.view(-1, 1, 1, 1).expand(-1, 1, self.num_tags, nbest), - end_tag.view(-1, 1, 1, nbest).expand(-1, 1, self.num_tags, nbest), - ) - - # The most probable path for each sequence - best_tags_arr = torch.zeros( - (batch_size, seq_length, nbest), dtype=torch.long, device=device - ) - best_tags = ( - torch.arange(nbest, dtype=torch.long, device=device) - .view(1, -1) - .expand(batch_size, -1) - ) - for idx in range(seq_length - 1, -1, -1): - best_tags = torch.gather( - history_idx[:, idx].view(batch_size, -1), 1, best_tags - ) - best_tags_arr[:, idx] = torch.div( - best_tags.data.view(batch_size, -1), nbest, rounding_mode="floor" - ) - - return torch.where(mask.unsqueeze(-1).bool(), best_tags_arr, oor_tag).permute( - 2, 0, 1 - ) - - -class BERT_WHITENING: - def __init__(self): - self.kernel = None - self.bias = None - - def compute_kernel_bias(self, sentence_vec): - vecs = torch.cat(sentence_vec, dim=0) - self.bias = -vecs.mean(dim=0, keepdims=True) - - cov = torch.cov(vecs.T) - u, s, vh = torch.linalg.svd(cov) - W = torch.matmul(u, torch.diag(s**0.5)) - self.kernel = torch.linalg.inv(W.T) - - def save_whiten(self, path): - whiten = {"kernel": self.kernel, "bias": self.bias} - torch.save(path, whiten) - - def load_whiten(self, path): - whiten = torch.load(path) - self.kernel = whiten["kernel"] - self.bias = whiten["bias"] - - def transform_and_normalize(self, vecs): - if not (self.kernel is None or self.bias is None): - vecs = (vecs + self.bias).mm(self.kernel) - return vecs / (vecs**2).sum(axis=1, keepdims=True) ** 0.5 - - -class GlobalPointer(nn.Module): - def __init__( - self, - hidden_size, - heads, - head_size, - RoPE=True, - max_len=512, - use_bias=True, - tril_mask=True, - ): - 
super().__init__() - self.heads = heads - self.head_size = head_size - self.RoPE = RoPE - self.tril_mask = tril_mask - self.RoPE = RoPE - - self.dense = nn.Linear(hidden_size, heads * head_size * 2, bias=use_bias) - if self.RoPE: - self.position_embedding = RoPEPositionEncoding(max_len, head_size) - - def forward(self, inputs, mask=None): - sequence_output = self.dense(inputs) # [..., heads*head_size*2] - sequence_output = torch.stack( - torch.chunk(sequence_output, self.heads, dim=-1), dim=-2 - ) # [..., heads, head_size*2] - qw, kw = ( - sequence_output[..., : self.head_size], - sequence_output[..., self.head_size :], - ) # [..., heads, head_size] - - if self.RoPE: - qw = self.position_embedding(qw) - kw = self.position_embedding(kw) - - logits = torch.einsum( - "bmhd,bnhd->bhmn", qw, kw - ) # [btz, heads, seq_len, seq_len] - - if mask is not None: - attention_mask1 = 1 - mask.unsqueeze(1).unsqueeze(3) # [btz, 1, seq_len, 1] - attention_mask2 = 1 - mask.unsqueeze(1).unsqueeze(2) # [btz, 1, 1, seq_len] - logits = logits.masked_fill(attention_mask1.bool(), value=-float("inf")) - logits = logits.masked_fill(attention_mask2.bool(), value=-float("inf")) - - if self.tril_mask: - logits = logits - torch.tril(torch.ones_like(logits), -1) * 1e12 - - return logits / self.head_size**0.5 - - -class EfficientGlobalPointer(nn.Module): - def __init__( - self, - hidden_size, - heads, - head_size, - RoPE=True, - max_len=512, - use_bias=True, - tril_mask=True, - ): - super().__init__() - self.heads = heads - self.head_size = head_size - self.RoPE = RoPE - self.tril_mask = tril_mask - self.RoPE = RoPE - - self.p_dense = nn.Linear(hidden_size, head_size * 2, bias=use_bias) - self.q_dense = nn.Linear(head_size * 2, heads * 2, bias=use_bias) - if self.RoPE: - self.position_embedding = RoPEPositionEncoding(max_len, head_size) - - def forward(self, inputs, mask=None): - """inputs: [..., hdsz] - mask: [bez, seq_len], padding部分为0 - """ - sequence_output = self.p_dense(inputs) # [..., head_size*2] - qw, kw = ( - sequence_output[..., : self.head_size], - sequence_output[..., self.head_size :], - ) # [..., head_size] - - if self.RoPE: - qw = self.position_embedding(qw) - kw = self.position_embedding(kw) - - logits = ( - torch.einsum("bmd,bnd->bmn", qw, kw) / self.head_size**0.5 - ) - bias_input = self.q_dense(sequence_output) # [..., heads*2] - bias = torch.stack( - torch.chunk(bias_input, self.heads, dim=-1), dim=-2 - ).transpose( - 1, 2 - ) # [btz, heads, seq_len, 2] - logits = ( - logits.unsqueeze(1) + bias[..., :1] + bias[..., 1:].transpose(2, 3) - ) # [btz, heads, seq_len, seq_len] - - if mask is not None: - attention_mask1 = 1 - mask.unsqueeze(1).unsqueeze(3) # [btz, 1, seq_len, 1] - attention_mask2 = 1 - mask.unsqueeze(1).unsqueeze(2) # [btz, 1, 1, seq_len] - logits = logits.masked_fill(attention_mask1.bool(), value=-float("inf")) - logits = logits.masked_fill(attention_mask2.bool(), value=-float("inf")) - - if self.tril_mask: - logits = logits - torch.tril(torch.ones_like(logits), -1) * 1e12 - - return logits - - -class TplinkerHandshakingKernel(nn.Module): - def __init__(self, hidden_size, shaking_type, inner_enc_type=""): - super().__init__() - self.shaking_type = shaking_type - if shaking_type == "cat": - self.combine_fc = nn.Linear(hidden_size * 2, hidden_size) - elif shaking_type == "cat_plus": - self.combine_fc = nn.Linear(hidden_size * 3, hidden_size) - elif shaking_type == "cln": - self.tp_cln = LayerNorm(hidden_size, conditional_size=hidden_size) - elif shaking_type == "cln_plus": - self.tp_cln = 
LayerNorm(hidden_size, conditional_size=hidden_size) - self.inner_context_cln = LayerNorm( - hidden_size, conditional_size=hidden_size - ) - - self.inner_enc_type = inner_enc_type - if inner_enc_type == "mix_pooling": - self.lamtha = nn.Parameter(torch.rand(hidden_size)) - elif inner_enc_type == "lstm": - self.inner_context_lstm = nn.LSTM( - hidden_size, - hidden_size, - num_layers=1, - bidirectional=False, - batch_first=True, - ) - - def enc_inner_hiddens(self, seq_hiddens, inner_enc_type="lstm"): - # seq_hiddens: (batch_size, seq_len, hidden_size) - def pool(seqence, pooling_type): - if pooling_type == "mean_pooling": - pooling = torch.mean(seqence, dim=-2) - elif pooling_type == "max_pooling": - pooling, _ = torch.max(seqence, dim=-2) - elif pooling_type == "mix_pooling": - pooling = ( - self.lamtha * torch.mean(seqence, dim=-2) - + (1 - self.lamtha) * torch.max(seqence, dim=-2)[0] - ) - return pooling - - if "pooling" in inner_enc_type: - inner_context = torch.stack( - [ - pool(seq_hiddens[:, : i + 1, :], inner_enc_type) - for i in range(seq_hiddens.size()[1]) - ], - dim=1, - ) - elif inner_enc_type == "lstm": - inner_context, _ = self.inner_context_lstm(seq_hiddens) - - return inner_context - - def forward(self, seq_hiddens): - """ - seq_hiddens: (batch_size, seq_len, hidden_size) - return: - shaking_hiddenss: (batch_size, (1 + seq_len) * seq_len / 2, hidden_size) (32, 5+4+3+2+1, 5) - """ - seq_len = seq_hiddens.size()[-2] - shaking_hiddens_list = [] - for ind in range(seq_len): - hidden_each_step = seq_hiddens[:, ind, :] - visible_hiddens = seq_hiddens[:, ind:, :] # ind: only look back - repeat_hiddens = hidden_each_step[:, None, :].repeat(1, seq_len - ind, 1) - - if self.shaking_type == "cat": - shaking_hiddens = torch.cat([repeat_hiddens, visible_hiddens], dim=-1) - shaking_hiddens = torch.tanh(self.combine_fc(shaking_hiddens)) - elif self.shaking_type == "cat_plus": - inner_context = self.enc_inner_hiddens( - visible_hiddens, self.inner_enc_type - ) - shaking_hiddens = torch.cat( - [repeat_hiddens, visible_hiddens, inner_context], dim=-1 - ) - shaking_hiddens = torch.tanh(self.combine_fc(shaking_hiddens)) - elif self.shaking_type == "cln": - shaking_hiddens = self.tp_cln([visible_hiddens, repeat_hiddens]) - elif self.shaking_type == "cln_plus": - inner_context = self.enc_inner_hiddens( - visible_hiddens, self.inner_enc_type - ) - shaking_hiddens = self.tp_cln([visible_hiddens, repeat_hiddens]) - shaking_hiddens = self.inner_context_cln( - [shaking_hiddens, inner_context] - ) - - shaking_hiddens_list.append(shaking_hiddens) - long_shaking_hiddens = torch.cat(shaking_hiddens_list, dim=1) - return long_shaking_hiddens - -class MixUp(nn.Module): - def __init__(self, method="encoder", alpha=1.0, layer_mix=None): - super().__init__() - assert method in {"embed", "encoder", "hidden", None} - self.method = method - self.alpha = alpha - self.perm_index = None - self.lam = 0 - self.layer_mix = layer_mix - - def get_perm(self, inputs): - if isinstance(inputs, torch.Tensor): - return inputs[self.perm_index] - elif isinstance(inputs, (list, tuple)): - return [ - inp[self.perm_index] if isinstance(inp, torch.Tensor) else inp - for inp in inputs - ] - - def mix_up(self, output, output1): - if isinstance(output, torch.Tensor): - return self.lam * output + (1.0 - self.lam) * output1 - elif isinstance(output, (list, tuple)): - output_final = [] - for i in range(len(output)): - if output[i] is None: # conditional_emb=None - output_final.append(output[i]) - elif (not output[i].requires_grad) and ( - 
output[i].dtype in {torch.long, torch.int} - ): - output_final.append(torch.max(output[i], output1[i])) - else: - output_final.append( - self.lam * output[i] + (1.0 - self.lam) * output1[i] - ) - return output_final - else: - raise ValueError("Illegal model output") - - def encode(self, model, inputs): - batch_size = inputs[0].shape[0] - device = inputs[0].device - self.lam = np.random.beta(self.alpha, self.alpha) - self.perm_index = torch.randperm(batch_size).to(device) - - if self.method is None: - output = model(inputs) - output1 = self.get_perm(output) - return [output, output1] - - elif self.method == "encoder": - output = model(inputs) - output1 = self.get_perm(output) - output_final = self.mix_up(output, output1) - - elif self.method == "embed": - output = model.apply_embeddings(inputs) - output1 = self.get_perm(output) - output_final = self.mix_up(output, output1) - # Main - output_final = model.apply_main_layers(output_final) - # Final - output_final = model.apply_final_layers(output_final) - - elif self.method == "hidden": - if self.layer_mix is None: - try: - layer_mix = random.randint(0, len(model.encoderLayer)) - except: - warnings.warn("LayerMix random failded") - layer_mix = 0 - else: - layer_mix = self.layer_mix - - def apply_on_layer_end(l_i, output): - if l_i == layer_mix: - output1 = self.get_perm(output) - return self.mix_up(output, output1) - else: - return output - - model.apply_on_layer_end = apply_on_layer_end - output_final = model(inputs) - return output_final - - def forward(self, criterion, y_pred, y_true): - y_true1 = y_true[self.perm_index] - return self.lam * criterion(y_pred, y_true) + (1 - self.lam) * criterion( - y_pred, y_true1 - ) +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
+ +import math +import random +import warnings +from typing import List, Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from bert4torch.activations import get_activation +from bert4torch.snippets import get_sinusoid_encoding_table, take_along_dim +from torch.functional import Tensor + + +class LayerNorm(nn.Module): + def __init__( + self, + hidden_size, + eps=1e-12, + conditional_size=False, + weight=True, + bias=True, + norm_mode="normal", + **kwargs, + ): + super(LayerNorm, self).__init__() + + if weight: + self.weight = nn.Parameter(torch.ones(hidden_size)) + if bias: + self.bias = nn.Parameter(torch.zeros(hidden_size)) + self.norm_mode = norm_mode + + self.eps = eps + self.conditional_size = conditional_size + if conditional_size: + self.dense1 = nn.Linear(conditional_size, hidden_size, bias=False) + self.dense1.weight.data.uniform_(0, 0) + self.dense2 = nn.Linear(conditional_size, hidden_size, bias=False) + self.dense2.weight.data.uniform_(0, 0) + + def forward(self, x): + inputs = x[0] + + if self.norm_mode == "rmsnorm": + variance = inputs.to(torch.float32).pow(2).mean(-1, keepdim=True) + o = inputs * torch.rsqrt(variance + self.eps) + else: + u = inputs.mean(-1, keepdim=True) + s = (inputs - u).pow(2).mean(-1, keepdim=True) + o = (inputs - u) / torch.sqrt(s + self.eps) + + if not hasattr(self, "weight"): + self.weight = 1 + if not hasattr(self, "bias"): + self.bias = 0 + + if self.conditional_size: + cond = x[1] + for _ in range(len(inputs.shape) - len(cond.shape)): + cond = cond.unsqueeze(dim=1) + return (self.weight + self.dense1(cond)) * o + ( + self.bias + self.dense2(cond) + ) + else: + return self.weight * o + self.bias + + +class MultiHeadAttentionLayer(nn.Module): + def __init__( + self, + hidden_size, + num_attention_heads, + attention_probs_dropout_prob, + attention_scale=True, + return_attention_scores=False, + bias=True, + **kwargs, + ): + super(MultiHeadAttentionLayer, self).__init__() + + assert hidden_size % num_attention_heads == 0 + + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.attention_head_size = int(hidden_size / num_attention_heads) + self.attention_scale = attention_scale + self.return_attention_scores = return_attention_scores + + self.bias = bias + self.q = nn.Linear(hidden_size, hidden_size, bias=bias) + self.k = nn.Linear(hidden_size, hidden_size, bias=bias) + self.v = nn.Linear(hidden_size, hidden_size, bias=bias) + self.o = nn.Linear(hidden_size, hidden_size, bias=bias) + self.dropout = nn.Dropout(attention_probs_dropout_prob) + + self.a_bias, self.p_bias = kwargs.get("a_bias"), kwargs.get("p_bias") + + if self.p_bias == "typical_relative": # nezha + self.relative_positions_encoding = RelativePositionsEncoding( + qlen=kwargs.get("max_position"), + klen=kwargs.get("max_position"), + embedding_size=self.attention_head_size, + max_relative_position=kwargs.get("max_relative_position"), + ) + elif self.p_bias == "rotary": # roformer + self.relative_positions_encoding = RoPEPositionEncoding( + max_position=kwargs.get("max_position"), + embedding_size=self.attention_head_size, + ) + elif self.p_bias == "t5_relative": # t5 + self.relative_positions = RelativePositionsEncodingT5( + qlen=kwargs.get("max_position"), + klen=kwargs.get("max_position"), + relative_attention_num_buckets=kwargs.get( + "relative_attention_num_buckets" + ), + is_decoder=kwargs.get("is_decoder"), + ) + self.relative_positions_encoding = nn.Embedding( + 
kwargs.get("relative_attention_num_buckets"), self.num_attention_heads + ) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + ): + mixed_query_layer = self.q(hidden_states) + if encoder_hidden_states is not None: + mixed_key_layer = self.k(encoder_hidden_states) + mixed_value_layer = self.v(encoder_hidden_states) + attention_mask = encoder_attention_mask + else: + mixed_key_layer = self.k(hidden_states) + mixed_value_layer = self.v(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + if self.p_bias == "rotary": + query_layer = self.relative_positions_encoding(query_layer) + key_layer = self.relative_positions_encoding(key_layer) + + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if (self.p_bias == "typical_relative") and hasattr( + self, "relative_positions_encoding" + ): + relations_keys = self.relative_positions_encoding( + attention_scores.shape[-1], attention_scores.shape[-1] + ) + key_position_scores_r_t = torch.einsum( + "bnih,ijh->bnij", query_layer, relations_keys + ) + attention_scores = attention_scores + key_position_scores_r_t + elif (self.p_bias == "t5_relative") and hasattr( + self, "relative_positions_encoding" + ): + relations_keys = self.relative_positions( + attention_scores.shape[-1], attention_scores.shape[-1] + ) + key_position_scores_r_t = ( + self.relative_positions_encoding(relations_keys) + .permute([2, 0, 1]) + .unsqueeze(0) + ) + attention_scores = attention_scores + key_position_scores_r_t + + if self.attention_scale: + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + attention_mask = ( + 1.0 - attention_mask + ) * -10000.0 + attention_scores = attention_scores + attention_mask + + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = self.dropout(attention_probs) + context_layer = torch.matmul( + attention_probs, value_layer + ) # [batch_size, num_attention_heads, query_len, attention_head_size] + + if (self.p_bias == "typical_relative") and hasattr( + self, "relative_positions_encoding" + ): + relations_values = self.relative_positions_encoding( + attention_scores.shape[-1], attention_scores.shape[-1] + ) + value_position_scores_r_t = torch.einsum( + "bnij,ijh->bnih", attention_probs, relations_values + ) + context_layer = context_layer + value_position_scores_r_t + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + if self.return_attention_scores: + return self.o(context_layer), attention_scores + else: + return self.o(context_layer) + + +class PositionWiseFeedForward(nn.Module): + def __init__( + self, + hidden_size, + intermediate_size, + dropout_rate=0.5, + hidden_act="gelu", + is_dropout=False, + bias=True, + **kwargs, + ): + super(PositionWiseFeedForward, self).__init__() + + self.is_dropout = is_dropout + self.intermediate_act_fn = get_activation(hidden_act) + self.intermediateDense = nn.Linear(hidden_size, intermediate_size, bias=bias) + self.outputDense = 
nn.Linear(intermediate_size, hidden_size, bias=bias) + if self.is_dropout: + self.dropout = nn.Dropout(dropout_rate) + + def forward(self, x): + # x shape: (batch size, seq len, hidden_size) + if self.is_dropout: + x = self.dropout(self.intermediate_act_fn(self.intermediateDense(x))) + else: + x = self.intermediate_act_fn(self.intermediateDense(x)) + + # x shape: (batch size, seq len, intermediate_size) + x = self.outputDense(x) + + # x shape: (batch size, seq len, hidden_size) + return x + + +class GatedAttentionUnit(nn.Module): + def __init__( + self, + hidden_size, + attention_key_size, + intermediate_size, + attention_probs_dropout_prob, + hidden_act, + is_dropout=False, + attention_scale=True, + bias=True, + normalization="softmax_plus", + **kwargs, + ): + super().__init__() + self.intermediate_size = intermediate_size + self.attention_head_size = attention_key_size + self.attention_scale = attention_scale + self.is_dropout = is_dropout + self.normalization = normalization + self.hidden_fn = get_activation(hidden_act) + self.dropout = nn.Dropout(attention_probs_dropout_prob) + self.i_dense = nn.Linear( + hidden_size, self.intermediate_size * 2 + attention_key_size, bias=bias + ) + self.offsetscale = self.OffsetScale(attention_key_size, heads=2, bias=bias) + self.o_dense = nn.Linear(self.intermediate_size, hidden_size, bias=bias) + + self.a_bias, self.p_bias = kwargs.get("a_bias"), kwargs.get("p_bias") + if self.p_bias == "rotary": # RoPE + self.relative_positions_encoding = RoPEPositionEncoding( + max_position=kwargs.get("max_position"), + embedding_size=self.attention_head_size, + ) + + def forward(self, hidden_states, attention_mask): + hidden_states = self.hidden_fn(self.i_dense(hidden_states)) + u, v, qk = hidden_states.split( + [self.intermediate_size, self.intermediate_size, self.attention_head_size], + dim=-1, + ) + q, k = self.offsetscale(qk) + + if self.p_bias == "rotary": + q = self.relative_positions_encoding(q) + k = self.relative_positions_encoding(k) + + # Attention + attention_scores = torch.einsum( + "b i d, b j d -> b i j", q, k + ) # [btz, seq_len, seq_len] + if self.attention_scale: + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + if attention_mask is not None: + attention_mask = (1.0 - attention_mask) * -1e12 + attention_scores = attention_scores + attention_mask.squeeze(1) + + # 归一化 + attention_scores = self.attention_normalize( + attention_scores, -1, self.normalization + ) + + if self.is_dropout: + attention_scores = self.dropout(attention_scores) + + # 计算输出 + out = self.o_dense( + u * torch.einsum("b i j, b j d -> b i d", attention_scores, v) + ) + return out + + def attention_normalize(self, a, dim=-1, method="softmax"): + if method == "softmax": + return F.softmax(a, dim=dim) + else: + mask = (a > -1e11).float() + l = torch.maximum( + torch.sum(mask, dim=dim, keepdims=True), torch.tensor(1).to(mask) + ) + if method == "squared_relu": + return F.relu(a) ** 2 / l + elif method == "softmax_plus": + return F.softmax( + a * torch.log(l) / torch.log(torch.tensor(512)).to(mask), dim=dim + ) + return a + + class OffsetScale(nn.Module): + def __init__(self, head_size, heads=1, bias=True): + super().__init__() + self.gamma = nn.Parameter(torch.ones(heads, head_size)) + self.bias = bias + if bias: + self.beta = nn.Parameter(torch.zeros(heads, head_size)) + nn.init.normal_(self.gamma, std=0.02) + + def forward(self, x): + out = torch.einsum("... d, h d -> ... 
h d", x, self.gamma) + if self.bias: + out = out + self.beta + return out.unbind(dim=-2) + + +class BertEmbeddings(nn.Module): + def __init__( + self, + vocab_size, + embedding_size, + hidden_size, + max_position, + segment_vocab_size, + shared_segment_embeddings, + drop_rate, + conditional_size=False, + **kwargs, + ): + super(BertEmbeddings, self).__init__() + self.shared_segment_embeddings = shared_segment_embeddings + self.word_embeddings = nn.Embedding(vocab_size, embedding_size, padding_idx=0) + + if kwargs.get("p_bias") == "sinusoid": + self.position_embeddings = SinusoidalPositionEncoding( + max_position, embedding_size + ) + elif kwargs.get("p_bias") in { + "rotary", + "typical_relative", + "t5_relative", + "other_relative", + }: + pass + elif max_position > 0: + self.position_embeddings = nn.Embedding(max_position, embedding_size) + + if (segment_vocab_size > 0) and (not shared_segment_embeddings): + self.segment_embeddings = nn.Embedding(segment_vocab_size, embedding_size) + + # emb_scale + self.emb_scale = kwargs.get("emb_scale", 1) + + # LayerNorm + self.layerNorm = LayerNorm( + embedding_size, eps=1e-12, conditional_size=conditional_size, **kwargs + ) + self.dropout = nn.Dropout(drop_rate) + + if embedding_size != hidden_size: + self.embedding_hidden_mapping_in = nn.Linear(embedding_size, hidden_size) + + def forward( + self, token_ids, segment_ids=None, conditional_emb=None, additional_embs=None + ): + if (not token_ids.requires_grad) and ( + token_ids.dtype in {torch.long, torch.int} + ): + words_embeddings = self.word_embeddings(token_ids) + else: + words_embeddings = token_ids + + if hasattr(self, "segment_embeddings"): + segment_ids = ( + torch.zeros_like(token_ids) if segment_ids is None else segment_ids + ) + segment_embeddings = self.segment_embeddings(segment_ids) + embeddings = words_embeddings + segment_embeddings + elif self.shared_segment_embeddings: + segment_ids = ( + torch.zeros_like(token_ids) if segment_ids is None else segment_ids + ) + segment_embeddings = self.word_embeddings(segment_ids) + embeddings = words_embeddings + segment_embeddings + else: + embeddings = words_embeddings + + if additional_embs is not None: + for emb in additional_embs: + embeddings += emb + + if hasattr(self, "position_embeddings"): + seq_length = token_ids.size(1) + position_ids = torch.arange( + seq_length, dtype=torch.long, device=token_ids.device + ) + position_ids = position_ids.unsqueeze(0).repeat(token_ids.shape[0], 1) + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + + if self.emb_scale != 1: + embeddings = embeddings * self.emb_scale + + if hasattr(self, "layerNorm"): + embeddings = self.layerNorm((embeddings, conditional_emb)) + embeddings = self.dropout(embeddings) + + if hasattr(self, "embedding_hidden_mapping_in"): + embeddings = self.embedding_hidden_mapping_in(embeddings) + return embeddings + + +class BertLayer(nn.Module): + def __init__( + self, + hidden_size, + num_attention_heads, + dropout_rate, + attention_probs_dropout_prob, + intermediate_size, + hidden_act, + is_dropout=False, + conditional_size=False, + **kwargs, + ): + super(BertLayer, self).__init__() + self.multiHeadAttention = MultiHeadAttentionLayer( + hidden_size, num_attention_heads, attention_probs_dropout_prob, **kwargs + ) + self.dropout1 = nn.Dropout(dropout_rate) + self.layerNorm1 = LayerNorm( + hidden_size, eps=1e-12, conditional_size=conditional_size, **kwargs + ) + self.feedForward = PositionWiseFeedForward( + hidden_size, + 
intermediate_size, + dropout_rate, + hidden_act, + is_dropout=is_dropout, + **kwargs, + ) + self.dropout2 = nn.Dropout(dropout_rate) + self.layerNorm2 = LayerNorm( + hidden_size, eps=1e-12, conditional_size=conditional_size, **kwargs + ) + self.is_decoder = kwargs.get("is_decoder") + if self.is_decoder: + self.crossAttention = MultiHeadAttentionLayer( + hidden_size, num_attention_heads, attention_probs_dropout_prob, **kwargs + ) + self.dropout3 = nn.Dropout(dropout_rate) + self.layerNorm3 = LayerNorm( + hidden_size, eps=1e-12, conditional_size=conditional_size, **kwargs + ) + + def forward( + self, + hidden_states, + attention_mask, + conditional_emb=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + ): + self_attn_output = self.multiHeadAttention( + hidden_states, attention_mask + ) + hidden_states = hidden_states + self.dropout1(self_attn_output) + hidden_states = self.layerNorm1((hidden_states, conditional_emb)) + + # cross attention + if self.is_decoder and encoder_hidden_states is not None: + cross_attn_output = self.crossAttention( + hidden_states, None, encoder_hidden_states, encoder_attention_mask + ) + hidden_states = hidden_states + self.dropout3(cross_attn_output) + hidden_states = self.layerNorm3((hidden_states, conditional_emb)) + + self_attn_output2 = self.feedForward(hidden_states) + hidden_states = hidden_states + self.dropout2(self_attn_output2) + hidden_states = self.layerNorm2((hidden_states, conditional_emb)) + return hidden_states + + +class T5Layer(BertLayer): + def __init__(self, *args, version="t5.1.0", **kwargs): + super().__init__(*args, **kwargs) + + if version.endswith("t5.1.1"): + kwargs["dropout_rate"] = args[2] + kwargs["hidden_act"] = args[5] + self.feedForward = self.T5PositionWiseFeedForward( + hidden_size=args[0], intermediate_size=args[4], **kwargs + ) + + if self.is_decoder and hasattr( + self.crossAttention, "relative_positions_encoding" + ): + del self.crossAttention.relative_positions_encoding + del self.crossAttention.relative_positions + + def forward( + self, + hidden_states, + attention_mask, + conditional_emb=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + ): + x = self.layerNorm1((hidden_states, conditional_emb)) + self_attn_output = self.multiHeadAttention(x, attention_mask) + hidden_states = hidden_states + self.dropout1(self_attn_output) + + # cross attention + if self.is_decoder and encoder_hidden_states is not None: + x = self.layerNorm3((hidden_states, conditional_emb)) + cross_attn_output = self.crossAttention( + x, None, encoder_hidden_states, encoder_attention_mask + ) + hidden_states = hidden_states + self.dropout3(cross_attn_output) + + x = self.layerNorm2((hidden_states, conditional_emb)) + ffn_output = self.feedForward(x) + hidden_states = hidden_states + self.dropout2(ffn_output) + return hidden_states + + class T5PositionWiseFeedForward(PositionWiseFeedForward): + def __init__(self, hidden_size, intermediate_size, **kwargs): + super().__init__(hidden_size, intermediate_size, **kwargs) + self.intermediateDense = nn.Linear( + hidden_size, intermediate_size, bias=False + ) + self.intermediateDense1 = nn.Linear( + hidden_size, intermediate_size, bias=False + ) + self.outputDense = nn.Linear(intermediate_size, hidden_size, bias=False) + + def forward(self, x): + # x shape: (batch size, seq len, hidden_size) + x_gelu = self.intermediate_act_fn(self.intermediateDense(x)) + x_linear = self.intermediateDense1(x) + x = x_gelu * x_linear + if self.is_dropout: + x = self.dropout(x) + + # x shape: (batch 
size, seq len, intermediate_size) + x = self.outputDense(x) + + # x shape: (batch size, seq len, hidden_size) + return x + + +class XlnetLayer(BertLayer): + def __init__( + self, + hidden_size, + num_attention_heads, + dropout_rate, + attention_probs_dropout_prob, + intermediate_size, + hidden_act, + **kwargs, + ): + super().__init__( + hidden_size, + num_attention_heads, + dropout_rate, + attention_probs_dropout_prob, + intermediate_size, + hidden_act, + **kwargs, + ) + self.pre_lnorm = kwargs.get("pre_lnorm") + self.multiHeadAttention = self.RelPartialLearnableMultiHeadAttn( + hidden_size, + num_attention_heads, + attention_probs_dropout_prob, + bias=False, + **kwargs, + ) + + def forward( + self, + hidden_states, + segment_ids, + pos_emb, + attention_mask, + mems_i, + conditional_emb=None, + ): + hidden_states_cat = ( + torch.cat([mems_i, hidden_states], 1) + if mems_i is not None + else hidden_states + ) + + # Attn + if self.pre_lnorm: + hidden_states_cat = self.layerNorm1((hidden_states_cat, conditional_emb)) + self_attn_output = self.multiHeadAttention( + hidden_states, hidden_states_cat, pos_emb, attention_mask, segment_ids + ) + hidden_states = hidden_states + self.dropout1(self_attn_output) + if not self.pre_lnorm: # post_lnorm + hidden_states = self.layerNorm1((hidden_states, conditional_emb)) + + # FFN + x = ( + self.layerNorm2((hidden_states, conditional_emb)) + if self.pre_lnorm + else hidden_states + ) + self_attn_output2 = self.feedForward(x) + hidden_states = hidden_states + self.dropout2(self_attn_output2) + if not self.pre_lnorm: # post_lnorm + hidden_states = self.layerNorm2((hidden_states, conditional_emb)) + return hidden_states + + class RelPartialLearnableMultiHeadAttn(MultiHeadAttentionLayer): + + def __init__( + self, *args, r_w_bias=None, r_r_bias=None, r_s_bias=None, **kwargs + ): + super().__init__(*args, **kwargs) + segment_vocab_size = kwargs.get("segment_vocab_size") + if r_r_bias is None or r_w_bias is None: # Biases are not shared + self.r_r_bias = nn.Parameter( + torch.FloatTensor( + self.num_attention_heads, self.attention_head_size + ) + ) + self.r_w_bias = nn.Parameter( + torch.FloatTensor( + self.num_attention_heads, self.attention_head_size + ) + ) + if segment_vocab_size > 0: + self.r_s_bias = nn.Parameter( + torch.FloatTensor( + self.num_attention_heads, self.attention_head_size + ) + ) + else: + self.r_r_bias = r_r_bias + self.r_w_bias = r_w_bias + self.r_s_bias = r_s_bias + if segment_vocab_size > 0: + # self.seg_embed = nn.Embedding(segment_vocab_size, self.hidden_size) + self.seg_embed = nn.Parameter( + torch.FloatTensor( + segment_vocab_size, + self.num_attention_heads, + self.attention_head_size, + ) + ) + + self.r = nn.Linear(self.hidden_size, self.hidden_size, bias=self.bias) + self.rel_shift_opt = kwargs.get("rel_shift_opt") + + @staticmethod + def rel_shift(x, zero_triu=False): + q_len, k_len = x.size(2), x.size(-1) + zero_pad = torch.zeros( + (*x.size()[:2], q_len, 1), device=x.device, dtype=x.dtype + ) + x_padded = torch.cat([zero_pad, x], dim=-1) + x_padded = x_padded.view(*x.size()[:2], k_len + 1, q_len) + x = x_padded[:, :, 1:, :].view_as(x) + if zero_triu: + ones = torch.ones((q_len, k_len), device=x.device) + x = x * torch.tril(ones, k_len - q_len)[None, None, :, :] + return x + + @staticmethod + def rel_shift_bnij(x, klen=-1): + x_size = x.shape + x = x.reshape(x_size[0], x_size[1], x_size[3], x_size[2]) + x = x[:, :, 1:, :] + x = x.reshape(x_size[0], x_size[1], x_size[2], x_size[3] - 1) + x = torch.index_select( + x, 3, 
torch.arange(klen, device=x.device, dtype=torch.long) + ) + return x + + def forward(self, w, cat, r, attention_mask=None, seg_mat=None): + qlen, rlen, bsz = w.size(1), r.size(0), w.size(0) + + mixed_query_layer = self.q(cat)[:, -qlen:, :] + mixed_key_layer = self.k(cat) + mixed_value_layer = self.v(cat) + + w_head_q = self.transpose_for_scores( + mixed_query_layer + ) # [btz, n_head, q_len, d_head] + w_head_k = self.transpose_for_scores( + mixed_key_layer + ) # [btz, n_head, k_len, d_head] + w_head_v = self.transpose_for_scores( + mixed_value_layer + ) # [btz, n_head, k_len, d_head] + + r_head_k = self.r(r) # [hdsz, nhead*headsize] = [r_len, 1, nhead*headsize] + r_head_k = r_head_k.view( + rlen, self.num_attention_heads, self.attention_head_size + ) # rlen x n_head x d_head + + #### compute attention score + rw_head_q = w_head_q + self.r_w_bias.unsqueeze( + 1 + ) # [btz, n_head, q_len, d_head] + AC = torch.einsum( + "bnid,bnjd->bnij", (rw_head_q, w_head_k) + ) # [btz, n_head, q_len, k_len] + + rr_head_q = w_head_q + self.r_r_bias.unsqueeze( + 1 + ) # [btz, n_head, q_len, d_head] + BD = torch.einsum( + "bnid,jnd->bnij", (rr_head_q, r_head_k) + ) # [btz, n_head, q_len, k_len] + BD = ( + self.rel_shift_bnij(BD, klen=AC.shape[3]) + if self.rel_shift_opt == "xlnet" + else self.rel_shift(BD) + ) + + if hasattr(self, "seg_embed") and (self.r_r_bias is not None): + seg_mat = F.one_hot(seg_mat, 2).float() + EF = torch.einsum( + "bnid,snd->ibns", + w_head_q + self.r_s_bias.unsqueeze(1), + self.seg_embed, + ) + EF = torch.einsum("bijs,ibns->bnij", seg_mat, EF) + else: + EF = 0 + + # # [btz, n_head, q_len, k_len] + attention_scores = AC + BD + EF + if self.attention_scale: + attention_scores = attention_scores / math.sqrt( + self.attention_head_size + ) + + if attention_mask is not None and attention_mask.any().item(): + attention_mask = 1.0 - attention_mask + attention_scores = ( + attention_scores.float() + .masked_fill(attention_mask.bool(), -1e30) + .type_as(attention_mask) + ) + + # [btz, n_head, q_len, k_len] + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = self.dropout(attention_probs) + context_layer = torch.matmul( + attention_probs, w_head_v + ) # [batch_size, num_attention_heads, query_len, attention_head_size] + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + if self.return_attention_scores: + return self.o(context_layer), attention_scores + else: + return self.o(context_layer) + + +class AdaptiveEmbedding(nn.Module): + def __init__( + self, + vocab_size, + embedding_size, + hidden_size, + cutoffs, + div_val=1, + sample_softmax=False, + **kwargs, + ): + super().__init__() + self.vocab_size = vocab_size + self.embedding_size = embedding_size + self.cutoffs = cutoffs + [vocab_size] + self.div_val = div_val + self.hidden_size = hidden_size + self.emb_scale = hidden_size**0.5 + self.cutoff_ends = [0] + self.cutoffs + + self.emb_layers = nn.ModuleList() + self.emb_projs = nn.ParameterList() + if div_val == 1: + self.emb_layers.append( + nn.Embedding(vocab_size, embedding_size, sparse=sample_softmax > 0) + ) + if hidden_size != embedding_size: + self.emb_projs.append( + nn.Parameter(torch.FloatTensor(hidden_size, embedding_size)) + ) + else: + for i in range(len(self.cutoffs)): + l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] + d_emb_i = embedding_size // (div_val**i) + 
self.emb_layers.append(nn.Embedding(r_idx - l_idx, d_emb_i)) + self.emb_projs.append( + nn.Parameter(torch.FloatTensor(hidden_size, d_emb_i)) + ) + + def forward(self, token_ids): + if self.div_val == 1: + embed = self.emb_layers[0](token_ids) # [btz, seq_len, embedding_size] + if self.hidden_size != self.embedding_size: + embed = nn.functional.linear(embed, self.emb_projs[0]) + else: + param = next(self.parameters()) + inp_flat = token_ids.view(-1) + emb_flat = torch.zeros( + [inp_flat.size(0), self.hidden_size], + dtype=param.dtype, + device=param.device, + ) + for i in range(len(self.cutoffs)): + l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] + + mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx) + indices_i = mask_i.nonzero().squeeze() + + if indices_i.numel() == 0: + continue + + inp_i = inp_flat.index_select(0, indices_i) - l_idx + emb_i = self.emb_layers[i](inp_i) + emb_i = nn.functional.linear(emb_i, self.emb_projs[i]) + + emb_flat.index_copy_(0, indices_i, emb_i) + + embed_shape = token_ids.size() + (self.hidden_size,) + embed = emb_flat.view(embed_shape) + + embed.mul_(self.emb_scale) + + return embed + + +class Identity(nn.Module): + def __init__(self, *args, **kwargs): + super(Identity, self).__init__() + + def forward(self, *args): + return args[0] + + +class XlnetPositionsEncoding(nn.Module): + def __init__(self, embedding_size): + super().__init__() + self.demb = embedding_size + inv_freq = 1 / ( + 10000 ** (torch.arange(0.0, embedding_size, 2.0) / embedding_size) + ) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, pos_seq): + sinusoid_inp = torch.ger(pos_seq, self.inv_freq) + pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) + return pos_emb + + +class RelativePositionsEncoding(nn.Module): + + def __init__(self, qlen, klen, embedding_size, max_relative_position=127): + super(RelativePositionsEncoding, self).__init__() + # 生成相对位置矩阵 + vocab_size = max_relative_position * 2 + 1 + distance_mat = ( + torch.arange(klen)[None, :] - torch.arange(qlen)[:, None] + ) # 列数-行数, [query_len, key_len] + distance_mat_clipped = torch.clamp( + distance_mat, -max_relative_position, max_relative_position + ) + final_mat = distance_mat_clipped + max_relative_position + + embeddings_table = get_sinusoid_encoding_table(vocab_size, embedding_size) + + position_embeddings = nn.Embedding.from_pretrained( + embeddings_table, freeze=True + )(final_mat) + self.register_buffer("position_embeddings", position_embeddings) + + def forward(self, qlen, klen): + return self.position_embeddings[:qlen, :klen, :] + + +class RelativePositionsEncodingT5(nn.Module): + + def __init__(self, qlen, klen, relative_attention_num_buckets, is_decoder=False): + super(RelativePositionsEncodingT5, self).__init__() + context_position = torch.arange(qlen, dtype=torch.long)[:, None] + memory_position = torch.arange(klen, dtype=torch.long)[None, :] + relative_position = memory_position - context_position # shape (qlen, klen) + relative_position = self._relative_position_bucket( + relative_position, # shape (qlen, klen) + bidirectional=not is_decoder, + num_buckets=relative_attention_num_buckets, + ) + self.register_buffer("relative_position", relative_position) + + def forward(self, qlen, klen): + return self.relative_position[:qlen, :klen] + + @staticmethod + def _relative_position_bucket( + relative_position, bidirectional=True, num_buckets=32, max_distance=128 + ): + """直接来源于transformer""" + ret = 0 + n = -relative_position + if bidirectional: + num_buckets //= 2 + ret += (n < 
0).to( + torch.long + ) * num_buckets # mtf.to_int32(mtf.less(n, 0)) * num_buckets + n = torch.abs(n) + else: + n = torch.max(n, torch.zeros_like(n)) + # now n is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = n < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + val_if_large = torch.min( + val_if_large, torch.full_like(val_if_large, num_buckets - 1) + ) + + ret += torch.where(is_small, n, val_if_large) + return ret + + +class SinusoidalPositionEncoding(nn.Module): + """定义Sin-Cos位置Embedding""" + + def __init__(self, max_position, embedding_size): + super(SinusoidalPositionEncoding, self).__init__() + self.position_embeddings = nn.Embedding.from_pretrained( + get_sinusoid_encoding_table(max_position, embedding_size), freeze=True + ) + + def forward(self, position_ids): + return self.position_embeddings(position_ids) + + +class RoPEPositionEncoding(nn.Module): + def __init__(self, max_position, embedding_size): + super(RoPEPositionEncoding, self).__init__() + position_embeddings = get_sinusoid_encoding_table( + max_position, embedding_size + ) # [seq_len, hdsz] + cos_position = position_embeddings[:, 1::2].repeat_interleave(2, dim=-1) + sin_position = position_embeddings[:, ::2].repeat_interleave(2, dim=-1) + self.register_buffer("cos_position", cos_position) + self.register_buffer("sin_position", sin_position) + + def forward(self, qw, seq_dim=-2): + seq_len = qw.shape[seq_dim] + qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], dim=-1).reshape_as(qw) + return qw * self.cos_position[:seq_len] + qw2 * self.sin_position[:seq_len] + + +class CRF(nn.Module): + def __init__( + self, + num_tags: int, + init_transitions: Optional[List[np.ndarray]] = None, + freeze=False, + ) -> None: + if num_tags <= 0: + raise ValueError(f"invalid number of tags: {num_tags}") + super().__init__() + self.num_tags = num_tags + if (init_transitions is None) and (not freeze): + self.start_transitions = nn.Parameter(torch.empty(num_tags)) + self.end_transitions = nn.Parameter(torch.empty(num_tags)) + self.transitions = nn.Parameter(torch.empty(num_tags, num_tags)) + nn.init.uniform_(self.start_transitions, -0.1, 0.1) + nn.init.uniform_(self.end_transitions, -0.1, 0.1) + nn.init.uniform_(self.transitions, -0.1, 0.1) + elif init_transitions is not None: + transitions = torch.tensor(init_transitions[0], dtype=torch.float) + start_transitions = torch.tensor(init_transitions[1], dtype=torch.float) + end_transitions = torch.tensor(init_transitions[2], dtype=torch.float) + + if not freeze: + self.transitions = nn.Parameter(transitions) + self.start_transitions = nn.Parameter(start_transitions) + self.end_transitions = nn.Parameter(end_transitions) + else: + self.register_buffer("transitions", transitions) + self.register_buffer("start_transitions", start_transitions) + self.register_buffer("end_transitions", end_transitions) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(num_tags={self.num_tags})" + + def forward( + self, + emissions: torch.Tensor, + mask: torch.ByteTensor, + tags: torch.LongTensor, + reduction: str = "mean", + ) -> torch.Tensor: + """Compute the conditional log likelihood of a sequence of tags given emission scores. 
+ emissions: [btz, seq_len, num_tags] + mask: [btz, seq_len] + tags: [btz, seq_len] + """ + if reduction not in ("none", "sum", "mean", "token_mean"): + raise ValueError(f"invalid reduction: {reduction}") + if mask.dtype != torch.uint8: + mask = mask.byte() + self._validate(emissions, tags=tags, mask=mask) + + # shape: (batch_size,) + numerator = self._compute_score(emissions, tags, mask) + # shape: (batch_size,) + denominator = self._compute_normalizer(emissions, mask) + # shape: (batch_size,) + llh = denominator - numerator + + if reduction == "none": + return llh + if reduction == "sum": + return llh.sum() + if reduction == "mean": + return llh.mean() + return llh.sum() / mask.float().sum() + + def decode( + self, + emissions: torch.Tensor, + mask: Optional[torch.ByteTensor] = None, + nbest: Optional[int] = None, + pad_tag: Optional[int] = None, + ) -> List[List[List[int]]]: + """Find the most likely tag sequence using Viterbi algorithm.""" + if nbest is None: + nbest = 1 + if mask is None: + mask = torch.ones( + emissions.shape[:2], dtype=torch.uint8, device=emissions.device + ) + if mask.dtype != torch.uint8: + mask = mask.byte() + self._validate(emissions, mask=mask) + + best_path = self._viterbi_decode_nbest(emissions, mask, nbest, pad_tag) + return best_path[0] if nbest == 1 else best_path + + def _validate( + self, + emissions: torch.Tensor, + tags: Optional[torch.LongTensor] = None, + mask: Optional[torch.ByteTensor] = None, + ) -> None: + if emissions.dim() != 3: + raise ValueError( + f"emissions must have dimension of 3, got {emissions.dim()}" + ) + if emissions.size(2) != self.num_tags: + raise ValueError( + f"expected last dimension of emissions is {self.num_tags}, " + f"got {emissions.size(2)}" + ) + if tags is not None: + if emissions.shape[:2] != tags.shape: + raise ValueError( + "the first two dimensions of emissions and tags must match, " + f"got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}" + ) + if mask is not None: + if emissions.shape[:2] != mask.shape: + raise ValueError( + "the first two dimensions of emissions and mask must match, " + f"got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}" + ) + no_empty_seq_bf = mask[:, 0].all() + if not no_empty_seq_bf: + raise ValueError("mask of the first timestep must all be on") + + def _compute_score( + self, emissions: torch.Tensor, tags: torch.LongTensor, mask: torch.ByteTensor + ) -> torch.Tensor: + # emissions: (batch_size, seq_length, num_tags) + # tags: (batch_size, seq_length) + # mask: (batch_size, seq_length) + batch_size, seq_length = tags.shape + mask = mask.float() + + # Start transition score and first emission + # shape: (batch_size,) + score = self.start_transitions[tags[:, 0]] + score += emissions[torch.arange(batch_size), 0, tags[:, 0]] + + for i in range(1, seq_length): + # Transition score to next tag, only added if next timestep is valid (mask == 1) + # shape: (batch_size,) + score += self.transitions[tags[:, i - 1], tags[:, i]] * mask[:, i] + # Emission score for next tag, only added if next timestep is valid (mask == 1) + # shape: (batch_size,) + score += emissions[torch.arange(batch_size), i, tags[:, i]] * mask[:, i] + + # End transition score + # shape: (batch_size,) + seq_ends = mask.long().sum(dim=1) - 1 + # shape: (batch_size,) + last_tags = tags[torch.arange(batch_size), seq_ends] + # shape: (batch_size,) + score += self.end_transitions[last_tags] + + return score + + def _compute_normalizer( + self, emissions: torch.Tensor, mask: torch.ByteTensor + ) -> torch.Tensor: + # emissions: 
(batch_size, seq_length, num_tags) + # mask: (batch_size, seq_length) + seq_length = emissions.size(1) + + # Start transition score and first emission; score has size of + # (batch_size, num_tags) where for each batch, the j-th column stores + # the score that the first timestep has tag j + # shape: (batch_size, num_tags) + score = self.start_transitions + emissions[:, 0] + + for i in range(1, seq_length): + # Broadcast score for every possible next tag + # shape: (batch_size, num_tags, 1) + broadcast_score = score.unsqueeze(2) + + # Broadcast emission score for every possible current tag + # shape: (batch_size, 1, num_tags) + broadcast_emissions = emissions[:, i].unsqueeze(1) + + # Compute the score tensor of size (batch_size, num_tags, num_tags) where + # for each sample, entry at row i and column j stores the sum of scores of all + # possible tag sequences so far that end with transitioning from tag i to tag j + # and emitting + # shape: (batch_size, num_tags, num_tags) + next_score = broadcast_score + self.transitions + broadcast_emissions + + # Sum over all possible current tags, but we're in score space, so a sum + # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of + # all possible tag sequences so far, that end in tag i + # shape: (batch_size, num_tags) + next_score = torch.logsumexp(next_score, dim=1) + + # Set score to the next score if this timestep is valid (mask == 1) + # shape: (batch_size, num_tags) + score = torch.where(mask[:, i].unsqueeze(1).bool(), next_score, score) + + # End transition score + # shape: (batch_size, num_tags) + score += self.end_transitions + + # Sum (log-sum-exp) over all possible tags + # shape: (batch_size,) + return torch.logsumexp(score, dim=1) + + def _viterbi_decode_nbest( + self, + emissions: torch.FloatTensor, + mask: torch.ByteTensor, + nbest: int, + pad_tag: Optional[int] = None, + ) -> List[List[List[int]]]: + # emissions: (batch_size, seq_length, num_tags) + # mask: (batch_size, seq_length) + # return: (nbest, batch_size, seq_length) + if pad_tag is None: + pad_tag = 0 + + device = emissions.device + batch_size, seq_length = mask.shape + + # Start transition and first emission + # shape: (batch_size, num_tags) + score = self.start_transitions + emissions[:, 0] + history_idx = torch.zeros( + (batch_size, seq_length, self.num_tags, nbest), + dtype=torch.long, + device=device, + ) + oor_idx = torch.zeros( + (batch_size, self.num_tags, nbest), dtype=torch.long, device=device + ) + oor_tag = torch.full( + (batch_size, seq_length, nbest), pad_tag, dtype=torch.long, device=device + ) + + # - score is a tensor of size (batch_size, num_tags) where for every batch, + # value at column j stores the score of the best tag sequence so far that ends + # with tag j + # - history_idx saves where the best tags candidate transitioned from; this is used + # when we trace back the best tag sequence + # - oor_idx saves the best tags candidate transitioned from at the positions + # where mask is 0, i.e. 
out of range (oor) + + # Viterbi algorithm recursive case: we compute the score of the best tag sequence + # for every possible next tag + for i in range(1, seq_length): + if i == 1: + broadcast_score = score.unsqueeze(-1) + broadcast_emission = emissions[:, i].unsqueeze(1) + # shape: (batch_size, num_tags, num_tags) + next_score = broadcast_score + self.transitions + broadcast_emission + else: + broadcast_score = score.unsqueeze(-1) + broadcast_emission = emissions[:, i].unsqueeze(1).unsqueeze(2) + # shape: (batch_size, num_tags, nbest, num_tags) + next_score = ( + broadcast_score + self.transitions.unsqueeze(1) + broadcast_emission + ) + + # Find the top `nbest` maximum score over all possible current tag + # shape: (batch_size, nbest, num_tags) + next_score, indices = next_score.view(batch_size, -1, self.num_tags).topk( + nbest, dim=1 + ) + + if i == 1: + score = score.unsqueeze(-1).expand(-1, -1, nbest) + indices = indices * nbest + + # convert to shape: (batch_size, num_tags, nbest) + next_score = next_score.transpose(2, 1) + indices = indices.transpose(2, 1) + + # Set score to the next score if this timestep is valid (mask == 1) + # and save the index that produces the next score + # shape: (batch_size, num_tags, nbest) + score = torch.where( + mask[:, i].unsqueeze(-1).unsqueeze(-1).bool(), next_score, score + ) + indices = torch.where( + mask[:, i].unsqueeze(-1).unsqueeze(-1).bool(), indices, oor_idx + ) + history_idx[:, i - 1] = indices + + # End transition score shape: (batch_size, num_tags, nbest) + end_score = score + self.end_transitions.unsqueeze(-1) + _, end_tag = end_score.view(batch_size, -1).topk(nbest, dim=1) + + # shape: (batch_size,) + seq_ends = mask.long().sum(dim=1) - 1 + + # insert the best tag at each sequence end (last position with mask == 1) + history_idx.scatter_( + 1, + seq_ends.view(-1, 1, 1, 1).expand(-1, 1, self.num_tags, nbest), + end_tag.view(-1, 1, 1, nbest).expand(-1, 1, self.num_tags, nbest), + ) + + # The most probable path for each sequence + best_tags_arr = torch.zeros( + (batch_size, seq_length, nbest), dtype=torch.long, device=device + ) + best_tags = ( + torch.arange(nbest, dtype=torch.long, device=device) + .view(1, -1) + .expand(batch_size, -1) + ) + for idx in range(seq_length - 1, -1, -1): + best_tags = torch.gather( + history_idx[:, idx].view(batch_size, -1), 1, best_tags + ) + best_tags_arr[:, idx] = torch.div( + best_tags.data.view(batch_size, -1), nbest, rounding_mode="floor" + ) + + return torch.where(mask.unsqueeze(-1).bool(), best_tags_arr, oor_tag).permute( + 2, 0, 1 + ) + + +class BERT_WHITENING: + def __init__(self): + self.kernel = None + self.bias = None + + def compute_kernel_bias(self, sentence_vec): + vecs = torch.cat(sentence_vec, dim=0) + self.bias = -vecs.mean(dim=0, keepdims=True) + + cov = torch.cov(vecs.T) + u, s, vh = torch.linalg.svd(cov) + W = torch.matmul(u, torch.diag(s**0.5)) + self.kernel = torch.linalg.inv(W.T) + + def save_whiten(self, path): + whiten = {"kernel": self.kernel, "bias": self.bias} + torch.save(path, whiten) + + def load_whiten(self, path): + whiten = torch.load(path) + self.kernel = whiten["kernel"] + self.bias = whiten["bias"] + + def transform_and_normalize(self, vecs): + if not (self.kernel is None or self.bias is None): + vecs = (vecs + self.bias).mm(self.kernel) + return vecs / (vecs**2).sum(axis=1, keepdims=True) ** 0.5 + + +class GlobalPointer(nn.Module): + def __init__( + self, + hidden_size, + heads, + head_size, + RoPE=True, + max_len=512, + use_bias=True, + tril_mask=True, + ): + 
super().__init__() + self.heads = heads + self.head_size = head_size + self.RoPE = RoPE + self.tril_mask = tril_mask + self.RoPE = RoPE + + self.dense = nn.Linear(hidden_size, heads * head_size * 2, bias=use_bias) + if self.RoPE: + self.position_embedding = RoPEPositionEncoding(max_len, head_size) + + def forward(self, inputs, mask=None): + sequence_output = self.dense(inputs) # [..., heads*head_size*2] + sequence_output = torch.stack( + torch.chunk(sequence_output, self.heads, dim=-1), dim=-2 + ) # [..., heads, head_size*2] + qw, kw = ( + sequence_output[..., : self.head_size], + sequence_output[..., self.head_size :], + ) # [..., heads, head_size] + + if self.RoPE: + qw = self.position_embedding(qw) + kw = self.position_embedding(kw) + + logits = torch.einsum( + "bmhd,bnhd->bhmn", qw, kw + ) # [btz, heads, seq_len, seq_len] + + if mask is not None: + attention_mask1 = 1 - mask.unsqueeze(1).unsqueeze(3) # [btz, 1, seq_len, 1] + attention_mask2 = 1 - mask.unsqueeze(1).unsqueeze(2) # [btz, 1, 1, seq_len] + logits = logits.masked_fill(attention_mask1.bool(), value=-float("inf")) + logits = logits.masked_fill(attention_mask2.bool(), value=-float("inf")) + + if self.tril_mask: + logits = logits - torch.tril(torch.ones_like(logits), -1) * 1e12 + + return logits / self.head_size**0.5 + + +class EfficientGlobalPointer(nn.Module): + def __init__( + self, + hidden_size, + heads, + head_size, + RoPE=True, + max_len=512, + use_bias=True, + tril_mask=True, + ): + super().__init__() + self.heads = heads + self.head_size = head_size + self.RoPE = RoPE + self.tril_mask = tril_mask + self.RoPE = RoPE + + self.p_dense = nn.Linear(hidden_size, head_size * 2, bias=use_bias) + self.q_dense = nn.Linear(head_size * 2, heads * 2, bias=use_bias) + if self.RoPE: + self.position_embedding = RoPEPositionEncoding(max_len, head_size) + + def forward(self, inputs, mask=None): + """inputs: [..., hdsz] + mask: [bez, seq_len], padding部分为0 + """ + sequence_output = self.p_dense(inputs) # [..., head_size*2] + qw, kw = ( + sequence_output[..., : self.head_size], + sequence_output[..., self.head_size :], + ) # [..., head_size] + + if self.RoPE: + qw = self.position_embedding(qw) + kw = self.position_embedding(kw) + + logits = ( + torch.einsum("bmd,bnd->bmn", qw, kw) / self.head_size**0.5 + ) + bias_input = self.q_dense(sequence_output) # [..., heads*2] + bias = torch.stack( + torch.chunk(bias_input, self.heads, dim=-1), dim=-2 + ).transpose( + 1, 2 + ) # [btz, heads, seq_len, 2] + logits = ( + logits.unsqueeze(1) + bias[..., :1] + bias[..., 1:].transpose(2, 3) + ) # [btz, heads, seq_len, seq_len] + + if mask is not None: + attention_mask1 = 1 - mask.unsqueeze(1).unsqueeze(3) # [btz, 1, seq_len, 1] + attention_mask2 = 1 - mask.unsqueeze(1).unsqueeze(2) # [btz, 1, 1, seq_len] + logits = logits.masked_fill(attention_mask1.bool(), value=-float("inf")) + logits = logits.masked_fill(attention_mask2.bool(), value=-float("inf")) + + if self.tril_mask: + logits = logits - torch.tril(torch.ones_like(logits), -1) * 1e12 + + return logits + + +class TplinkerHandshakingKernel(nn.Module): + def __init__(self, hidden_size, shaking_type, inner_enc_type=""): + super().__init__() + self.shaking_type = shaking_type + if shaking_type == "cat": + self.combine_fc = nn.Linear(hidden_size * 2, hidden_size) + elif shaking_type == "cat_plus": + self.combine_fc = nn.Linear(hidden_size * 3, hidden_size) + elif shaking_type == "cln": + self.tp_cln = LayerNorm(hidden_size, conditional_size=hidden_size) + elif shaking_type == "cln_plus": + self.tp_cln = 
LayerNorm(hidden_size, conditional_size=hidden_size) + self.inner_context_cln = LayerNorm( + hidden_size, conditional_size=hidden_size + ) + + self.inner_enc_type = inner_enc_type + if inner_enc_type == "mix_pooling": + self.lamtha = nn.Parameter(torch.rand(hidden_size)) + elif inner_enc_type == "lstm": + self.inner_context_lstm = nn.LSTM( + hidden_size, + hidden_size, + num_layers=1, + bidirectional=False, + batch_first=True, + ) + + def enc_inner_hiddens(self, seq_hiddens, inner_enc_type="lstm"): + # seq_hiddens: (batch_size, seq_len, hidden_size) + def pool(seqence, pooling_type): + if pooling_type == "mean_pooling": + pooling = torch.mean(seqence, dim=-2) + elif pooling_type == "max_pooling": + pooling, _ = torch.max(seqence, dim=-2) + elif pooling_type == "mix_pooling": + pooling = ( + self.lamtha * torch.mean(seqence, dim=-2) + + (1 - self.lamtha) * torch.max(seqence, dim=-2)[0] + ) + return pooling + + if "pooling" in inner_enc_type: + inner_context = torch.stack( + [ + pool(seq_hiddens[:, : i + 1, :], inner_enc_type) + for i in range(seq_hiddens.size()[1]) + ], + dim=1, + ) + elif inner_enc_type == "lstm": + inner_context, _ = self.inner_context_lstm(seq_hiddens) + + return inner_context + + def forward(self, seq_hiddens): + """ + seq_hiddens: (batch_size, seq_len, hidden_size) + return: + shaking_hiddenss: (batch_size, (1 + seq_len) * seq_len / 2, hidden_size) (32, 5+4+3+2+1, 5) + """ + seq_len = seq_hiddens.size()[-2] + shaking_hiddens_list = [] + for ind in range(seq_len): + hidden_each_step = seq_hiddens[:, ind, :] + visible_hiddens = seq_hiddens[:, ind:, :] # ind: only look back + repeat_hiddens = hidden_each_step[:, None, :].repeat(1, seq_len - ind, 1) + + if self.shaking_type == "cat": + shaking_hiddens = torch.cat([repeat_hiddens, visible_hiddens], dim=-1) + shaking_hiddens = torch.tanh(self.combine_fc(shaking_hiddens)) + elif self.shaking_type == "cat_plus": + inner_context = self.enc_inner_hiddens( + visible_hiddens, self.inner_enc_type + ) + shaking_hiddens = torch.cat( + [repeat_hiddens, visible_hiddens, inner_context], dim=-1 + ) + shaking_hiddens = torch.tanh(self.combine_fc(shaking_hiddens)) + elif self.shaking_type == "cln": + shaking_hiddens = self.tp_cln([visible_hiddens, repeat_hiddens]) + elif self.shaking_type == "cln_plus": + inner_context = self.enc_inner_hiddens( + visible_hiddens, self.inner_enc_type + ) + shaking_hiddens = self.tp_cln([visible_hiddens, repeat_hiddens]) + shaking_hiddens = self.inner_context_cln( + [shaking_hiddens, inner_context] + ) + + shaking_hiddens_list.append(shaking_hiddens) + long_shaking_hiddens = torch.cat(shaking_hiddens_list, dim=1) + return long_shaking_hiddens + +class MixUp(nn.Module): + def __init__(self, method="encoder", alpha=1.0, layer_mix=None): + super().__init__() + assert method in {"embed", "encoder", "hidden", None} + self.method = method + self.alpha = alpha + self.perm_index = None + self.lam = 0 + self.layer_mix = layer_mix + + def get_perm(self, inputs): + if isinstance(inputs, torch.Tensor): + return inputs[self.perm_index] + elif isinstance(inputs, (list, tuple)): + return [ + inp[self.perm_index] if isinstance(inp, torch.Tensor) else inp + for inp in inputs + ] + + def mix_up(self, output, output1): + if isinstance(output, torch.Tensor): + return self.lam * output + (1.0 - self.lam) * output1 + elif isinstance(output, (list, tuple)): + output_final = [] + for i in range(len(output)): + if output[i] is None: # conditional_emb=None + output_final.append(output[i]) + elif (not output[i].requires_grad) and ( + 
output[i].dtype in {torch.long, torch.int} + ): + output_final.append(torch.max(output[i], output1[i])) + else: + output_final.append( + self.lam * output[i] + (1.0 - self.lam) * output1[i] + ) + return output_final + else: + raise ValueError("Illegal model output") + + def encode(self, model, inputs): + batch_size = inputs[0].shape[0] + device = inputs[0].device + self.lam = np.random.beta(self.alpha, self.alpha) + self.perm_index = torch.randperm(batch_size).to(device) + + if self.method is None: + output = model(inputs) + output1 = self.get_perm(output) + return [output, output1] + + elif self.method == "encoder": + output = model(inputs) + output1 = self.get_perm(output) + output_final = self.mix_up(output, output1) + + elif self.method == "embed": + output = model.apply_embeddings(inputs) + output1 = self.get_perm(output) + output_final = self.mix_up(output, output1) + # Main + output_final = model.apply_main_layers(output_final) + # Final + output_final = model.apply_final_layers(output_final) + + elif self.method == "hidden": + if self.layer_mix is None: + try: + layer_mix = random.randint(0, len(model.encoderLayer)) + except: + warnings.warn("LayerMix random failded") + layer_mix = 0 + else: + layer_mix = self.layer_mix + + def apply_on_layer_end(l_i, output): + if l_i == layer_mix: + output1 = self.get_perm(output) + return self.mix_up(output, output1) + else: + return output + + model.apply_on_layer_end = apply_on_layer_end + output_final = model(inputs) + return output_final + + def forward(self, criterion, y_pred, y_true): + y_true1 = y_true[self.perm_index] + return self.lam * criterion(y_pred, y_true) + (1 - self.lam) * criterion( + y_pred, y_true1 + ) diff --git a/models/nlp/plm/bert_base_ner/igie/Int8QAT/bert4torch/models.py b/models/nlp/plm/bert_base_ner/igie/Int8QAT/bert4torch/models.py index 5bf320cef6448392ffd9ac94763cd1379feeb2a7..8b79dac0adc63f6085785835bdee1f05aacdc29c 100644 --- a/models/nlp/plm/bert_base_ner/igie/Int8QAT/bert4torch/models.py +++ b/models/nlp/plm/bert_base_ner/igie/Int8QAT/bert4torch/models.py @@ -1,2363 +1,2363 @@ -# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. 
- -import copy -import json -import re -import warnings - -import torch -import torch.nn as nn -from bert4torch.activations import get_activation -from bert4torch.layers import ( - AdaptiveEmbedding, - BertEmbeddings, - BertLayer, - GatedAttentionUnit, - Identity, - LayerNorm, - T5Layer, - XlnetLayer, - XlnetPositionsEncoding, -) -from bert4torch.snippets import ( - FGM, - PGD, - VAT, - EarlyStopping, - IterDataset, - ProgbarLogger, - delete_arguments, - get_kw, - insert_arguments, - metric_mapping, - search_layer, - take_along_dim, -) - - -class BaseModel(nn.Module): - def __init__(self): - super(BaseModel, self).__init__() - ( - self.global_step, - self.local_step, - self.total_steps, - self.epoch, - self.train_dataloader, - ) = (0, 0, 0, 0, None) - self.callbacks = [] - - def compile( - self, - loss, - optimizer, - scheduler=None, - max_grad_norm=None, - use_amp=False, - metrics=None, - adversarial_train={"name": ""}, - ): - self.criterion = loss - self.optimizer = optimizer - self.scheduler = scheduler - self.max_grad_norm = max_grad_norm - self.use_amp = use_amp - if use_amp: - assert adversarial_train["name"] not in { - "vat", - "gradient_penalty", - }, "Amp and adversarial_train both run is not supported in current version" - from torch.cuda.amp import autocast - - self.autocast = autocast - self.scaler = torch.cuda.amp.GradScaler() - - if metrics is None: - metrics = [] - self.metrics = ["loss"] + [i for i in metrics if i != "loss"] - - # 对抗训练 - self.adversarial = adversarial_train - self.adversarial_initialize() - - def adversarial_initialize(self): - assert self.adversarial["name"] in { - "", - "fgm", - "pgd", - "vat", - "gradient_penalty", - }, "adversarial_train support fgm, pgd, vat and gradient_penalty mode" - self.adversarial["epsilon"] = self.adversarial.get("epsilon", 1.0) - self.adversarial["emb_name"] = self.adversarial.get( - "emb_name", "word_embeddings" - ) - - if self.adversarial["name"] == "fgm": - self.ad_train = FGM(self) - elif self.adversarial["name"] == "pgd": - self.adversarial["K"] = self.adversarial.get("K", 3) # 步数 - self.adversarial["alpha"] = self.adversarial.get("alpha", 0.3) # 学习率 - self.ad_train = PGD(self) - elif self.adversarial["name"] == "gradient_penalty": - pass - elif self.adversarial["name"] == "vat": - self.adversarial["K"] = self.adversarial.get("K", 3) - self.adversarial["noise_var"] = self.adversarial.get( - "noise_var", 1e-5 - ) - self.adversarial["noise_gamma"] = self.adversarial.get( - "noise_gamma", 1e-6 - ) - self.adversarial["adv_step_size"] = self.adversarial.get( - "adv_step_size", 1e-3 - ) - self.adversarial["adv_alpha"] = self.adversarial.get( - "adv_alpha", 1 - ) - self.adversarial["norm_type"] = self.adversarial.get( - "norm_type", "l2" - ) - self.ad_train = VAT(self, **self.adversarial) - - def adversarial_training( - self, train_X, train_y, output, loss, loss_detail, grad_accumulation_steps - ): - """对抗训练""" - if self.adversarial["name"] == "fgm": - self.ad_train.attack(**self.adversarial) - output, loss, loss_detail = self.train_step( - train_X, train_y, grad_accumulation_steps - ) - loss.backward() - self.ad_train.restore(**self.adversarial) - elif self.adversarial["name"] == "pgd": - self.ad_train.backup_grad() - for t in range(self.adversarial["K"]): - self.ad_train.attack(**self.adversarial, is_first_attack=(t == 0)) - if t != self.adversarial["K"] - 1: - self.optimizer.zero_grad() - else: - self.ad_train.restore_grad() - output, loss, loss_detail = self.train_step( - train_X, train_y, grad_accumulation_steps - ) - 
loss.backward() - self.ad_train.restore(**self.adversarial) - elif self.adversarial["name"] == "gradient_penalty": - para = search_layer(self, self.adversarial["emb_name"], retrun_first=True) - gp = (para.grad**2).sum() - loss += 0.5 * gp * self.adversarial["epsilon"] - loss.backward() - elif self.adversarial["name"] == "vat": - logit = output[0] if isinstance(output, (list, tuple)) else output - adv_loss = self.ad_train.virtual_adversarial_training(train_X, logit) - loss_detail.update({"loss_sup": loss.item(), "loss_unsup": adv_loss}) - loss += adv_loss if adv_loss else 0 - loss.backward() - - return loss, loss_detail - - def train_step(self, train_X, train_y, grad_accumulation_steps): - - def args_segmentate(train_X): - if isinstance(train_X, torch.Tensor): - pass - elif isinstance(self, (BaseModelDP, BaseModelDDP)): - if self.module.forward.__code__.co_argcount >= 3: - return True - elif self.forward.__code__.co_argcount >= 3: - return True - return False - - if self.use_amp: - with self.autocast(): - output = ( - self.forward(*train_X) - if args_segmentate(train_X) - else self.forward(train_X) - ) - loss_detail = self.criterion(output, train_y) - else: - output = ( - self.forward(*train_X) - if args_segmentate(train_X) - else self.forward(train_X) - ) - loss_detail = self.criterion(output, train_y) - - if isinstance(loss_detail, torch.Tensor): - loss = loss_detail - loss_detail = {} - elif isinstance(loss_detail, dict): - loss = loss_detail["loss"] - del loss_detail["loss"] - elif isinstance(loss_detail, (tuple, list)): - loss = loss_detail[0] - loss_detail = { - f"loss{i}": v for i, v in enumerate(loss_detail[1:], start=1) - } - else: - raise ValueError("Return loss only support Tensor/dict/tuple/list format") - # 梯度累积 - loss = loss / grad_accumulation_steps if grad_accumulation_steps > 1 else loss - return output, loss, loss_detail - - def callback_fun(self, mode, logs={}): - if ( - isinstance(self, BaseModelDDP) - and self.master_rank != torch.distributed.get_rank() - ): - return - - if mode == "train_begin": - for callback in self.callbacks: - callback.on_train_begin() - elif mode == "epoch_begin": - for callback in self.callbacks: - callback.on_epoch_begin(self.global_step, self.epoch, logs) - elif mode == "batch_begin": - for callback in self.callbacks: - callback.on_batch_begin(self.global_step, self.local_step, logs) - elif mode == "batch_end": - for callback in self.callbacks: - callback.on_batch_end(self.global_step, self.local_step, logs) - elif mode == "epoch_end": - for callback in self.callbacks: - callback.on_epoch_end(self.global_step, self.epoch, logs) - elif mode == "train_end": - for callback in self.callbacks: - callback.on_train_end() - elif mode == "dataloader_end": - for callback in self.callbacks: - callback.on_dataloader_end() - - def fit( - self, - train_dataloader, - steps_per_epoch=None, - epochs=1, - grad_accumulation_steps=1, - callbacks=[], - ): - if isinstance(train_dataloader.dataset, IterDataset): - assert ( - steps_per_epoch is not None - ), "IterDataset should specify steps_per_epoch" - steps_per_epoch = ( - len(train_dataloader) if steps_per_epoch is None else steps_per_epoch - ) - self.total_steps = steps_per_epoch * epochs - self.global_step = 0 - self.train_dataloader = train_dataloader - train_dataloader_iter = iter(self.train_dataloader) - - self.callbacks = [ProgbarLogger(epochs, steps_per_epoch, self.metrics)] + ( - callbacks if isinstance(callbacks, (list, tuple)) else [callbacks] - ) - self.callback_fun("train_begin") - - self.bti = 0 - for 
epoch in range(epochs): - if isinstance( - self.train_dataloader.sampler, - torch.utils.data.distributed.DistributedSampler, - ): - self.train_dataloader.sampler.set_epoch(epoch) - self.epoch = epoch - self.callback_fun("epoch_begin") - for local_step in range(steps_per_epoch): - self.local_step = local_step - try: - batch = next(train_dataloader_iter) - except StopIteration: - self.callback_fun( - "dataloader_end" - ) - train_dataloader_iter = iter( - self.train_dataloader - ) - self.bti = 0 - batch = next(train_dataloader_iter) - train_X, train_y = batch - - if isinstance(train_X, (list, tuple)): - if isinstance(train_X[0], (list, tuple)): - btz = train_X[0][0].size(0) - else: - btz = train_X[0].size(0) - elif isinstance(train_X, torch.Tensor): - btz = train_X.size(0) - else: - raise ValueError("Input only support [list, tuple, tensor]") - logs = {"batch": self.local_step, "size": btz} - self.callback_fun("batch_begin", logs) - - self.train() - output, loss, loss_detail = self.train_step( - train_X, train_y, grad_accumulation_steps - ) - - retain_graph = ( - True - if self.adversarial["name"] in {"gradient_penalty", "vat"} - else False - ) - if self.use_amp: - scale_before_step = self.scaler.get_scale() - self.scaler.scale(loss).backward(retain_graph=retain_graph) - else: - loss.backward(retain_graph=retain_graph) - - loss, loss_detail = self.adversarial_training( - train_X, train_y, output, loss, loss_detail, grad_accumulation_steps - ) - - if (self.global_step + 1) % grad_accumulation_steps == 0: - skip_scheduler = False - if self.use_amp: - self.scaler.unscale_(self.optimizer) - if self.max_grad_norm is not None: - torch.nn.utils.clip_grad_norm_( - self.parameters(), self.max_grad_norm - ) - self.scaler.step(self.optimizer) - self.scaler.update() - skip_scheduler = self.scaler.get_scale() != scale_before_step - else: - if self.max_grad_norm is not None: - torch.nn.utils.clip_grad_norm_( - self.parameters(), self.max_grad_norm - ) - self.optimizer.step() - - self.optimizer.zero_grad() - if (self.scheduler is not None) and not skip_scheduler: - self.scheduler.step() - - # 添加log打印 - logs.update({"loss": loss.item()}) - logs_loss_detail = { - k: v.item() if isinstance(v, torch.Tensor) else v - for k, v in loss_detail.items() - } - logs.update(logs_loss_detail) - if self.global_step == 0: - self.callbacks[0].add_metrics( - list(logs_loss_detail.keys()), add_position=1 - ) - for metric in self.metrics: - tmp = metric_mapping(metric, output, train_y) - if tmp is not None: - logs[metric] = tmp - self.callback_fun("batch_end", logs) - - self.bti += 1 - self.global_step += 1 - self.callback_fun("epoch_end", logs) - callback_tmp = [ - callback_tmp - for callback_tmp in self.callbacks - if isinstance(callback_tmp, EarlyStopping) - ] - if callback_tmp and callback_tmp[0].stopped_epoch > 0: - break - self.callback_fun("train_end", logs) - - @torch.no_grad() - def predict(self, input_tensor_list, return_all=None): - self.eval() - if self.forward.__code__.co_argcount >= 3: - output = self.forward(*input_tensor_list) - else: - output = self.forward(input_tensor_list) - if return_all is None: - return output - elif ( - isinstance(output, (tuple, list)) - and isinstance(return_all, int) - and return_all < len(output) - ): - return output[return_all] - else: - raise ValueError("Return format error") - - def load_weights(self, load_path, strict=True, prefix=None): - state_dict = torch.load(load_path, map_location="cpu") - if prefix is None: - self.load_state_dict(state_dict, strict=strict) - else: - 
eval_str = ( - "self.variable_mapping()" - if prefix == "" - else f"self.{prefix}.variable_mapping()" - ) - mapping = {v: k for k, v in eval(eval_str).items()} - mapping = ( - mapping - if prefix == "" - else {k: f"{prefix}.{v}" for k, v in mapping.items()} - ) - state_dict_raw = {} - for k, v in state_dict.items(): - k = mapping.get(k, k) - state_dict_raw[k] = v - self.load_state_dict(state_dict_raw, strict=strict) - - def save_weights(self, save_path, prefix=None): - if prefix is None: - torch.save(self.state_dict(), save_path) - else: - eval_str = ( - "self.variable_mapping()" - if prefix == "" - else f"self.{prefix}.variable_mapping()" - ) - mapping = eval(eval_str) - mapping = ( - mapping - if prefix == "" - else {f"{prefix}.{k}": v for k, v in mapping.items()} - ) - state_dict_raw = {} - for k, v in self.state_dict().items(): - k = mapping.get(k, k) - state_dict_raw[k] = v - torch.save(state_dict_raw, save_path) - - -class BaseModelDP(BaseModel, nn.DataParallel): - - def __init__(self, *args, **kwargs): - nn.DataParallel.__init__(self, *args, **kwargs) - - -class BaseModelDDP(BaseModel, nn.parallel.DistributedDataParallel): - - def __init__(self, *args, master_rank=0, **kwargs): - self.master_rank = master_rank - nn.parallel.DistributedDataParallel.__init__(self, *args, **kwargs) - - -class BERT_BASE(BaseModel): - """模型基类""" - - def __init__( - self, - vocab_size, - hidden_size, - num_hidden_layers, - num_attention_heads, - intermediate_size, - hidden_act, - dropout_rate=None, - attention_probs_dropout_prob=None, - embedding_size=None, - attention_head_size=None, - attention_key_size=None, - initializer_range=0.02, - sequence_length=None, - keep_tokens=None, - compound_tokens=None, - residual_attention_scores=False, - ignore_invalid_weights=False, - keep_hidden_layers=None, - hierarchical_position=None, - **kwargs, - ): - super(BERT_BASE, self).__init__() - if keep_tokens is not None: - vocab_size = len(keep_tokens) - if compound_tokens is not None: - vocab_size += len(compound_tokens) - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.attention_head_size = ( - attention_head_size or self.hidden_size // self.num_attention_heads - ) - self.attention_key_size = attention_key_size or self.attention_head_size - self.intermediate_size = intermediate_size - self.dropout_rate = dropout_rate or 0 - self.attention_probs_dropout_prob = attention_probs_dropout_prob or 0 - self.hidden_act = hidden_act - self.embedding_size = embedding_size or hidden_size - self.initializer_range = initializer_range - self.sequence_length = sequence_length - self.keep_tokens = keep_tokens - self.compound_tokens = compound_tokens - self.attention_bias = None - self.position_bias = None - self.attention_scores = None - self.residual_attention_scores = residual_attention_scores - self.ignore_invalid_weights = ignore_invalid_weights - self.keep_hidden_layers = ( - set(range(num_hidden_layers)) - if keep_hidden_layers is None - else set(keep_hidden_layers) - ) - self.hierarchical_position = hierarchical_position - - def build( - self, - attention_caches=None, - layer_norm_cond=None, - layer_norm_cond_hidden_size=None, - layer_norm_cond_hidden_act=None, - additional_input_layers=None, - **kwargs, - ): - self.attention_caches = attention_caches or {} - self.output_all_encoded_layers = kwargs.get("output_all_encoded_layers", False) - - def forward(self, inputs): - # Embedding - outputs = 
self.apply_embeddings(inputs) - # Main - outputs = self.apply_main_layers(outputs) - # Final - outputs = self.apply_final_layers(outputs) - return outputs - - def init_model_weights(self, module): - if isinstance(module, (nn.Linear, nn.Embedding)) and ( - module.weight.requires_grad - ): - module.weight.data.normal_(mean=0.0, std=self.initializer_range) - elif isinstance(module, LayerNorm): - if ( - hasattr(module, "bias") and module.bias.requires_grad - ): - module.bias.data.zero_() - if hasattr(module, "weight") and module.weight.requires_grad: - module.weight.data.fill_(1.0) - if ( - isinstance(module, nn.Linear) - and (module.bias is not None) - and (module.bias.requires_grad) - ): - module.bias.data.zero_() - - def variable_mapping(self): - return {} - - def load_load_variable(self): - raise NotImplementedError - - def load_embeddings(self, embeddings): - if self.keep_tokens is not None: - embeddings = embeddings[self.keep_tokens] - - if self.compound_tokens is not None: - ext_embeddings = [] - for item in self.compound_tokens: - try: - ext_embeddings.append( - torch.mean(embeddings[item], 0) - * torch.ones_like(embeddings[item]) - ) - except IndexError: - ext_embeddings.append(torch.mean(embeddings, 0, keepdim=True)) - warnings.warn( - f"Initialize ext_embeddings from compound_tokens not in embedding index" - ) - embeddings = torch.cat([embeddings] + ext_embeddings, 0) - - return embeddings - - def load_pos_embeddings(self, embeddings): - if self.hierarchical_position is not None: - alpha = ( - 0.4 - if self.hierarchical_position is True - else self.hierarchical_position - ) - embeddings = embeddings - alpha * embeddings[:1] - embeddings = embeddings / (1 - alpha) - position_index = torch.arange(self.max_position)[:, None] - - embeddings_x = take_along_dim( - embeddings, - torch.div(position_index, embeddings.size(0), rounding_mode="trunc"), - dim=0, - ) - embeddings_y = take_along_dim( - embeddings, position_index % embeddings.size(0), dim=0 - ) - embeddings = alpha * embeddings_x + (1 - alpha) * embeddings_y - - return embeddings - - def load_weights_from_pytorch_checkpoint(self, checkpoint, mapping=None): - file_state_dict = torch.load(checkpoint, map_location="cpu") - mapping = mapping or self.variable_mapping() - parameters_set = set([i[0] for i in self.named_parameters()]) - - for layer_name in parameters_set: - if (layer_name in file_state_dict) and (layer_name not in mapping): - mapping.update({layer_name: layer_name}) - - state_dict_new = {} - for new_key, old_key in mapping.items(): - if new_key not in self.state_dict(): - continue - elif old_key in file_state_dict: - state_dict_new[new_key] = self.load_variable(file_state_dict, old_key) - elif (old_key not in file_state_dict) and (not self.ignore_invalid_weights): - print(f"[WARNIMG] {old_key} not found in pretrain models") - if new_key in parameters_set: - parameters_set.remove(new_key) - - if not self.ignore_invalid_weights: - for key in parameters_set: - print(f"[WARNIMG] Parameter {key} not loaded from pretrain models") - del file_state_dict - - self.load_state_dict(state_dict_new, strict=False) - - - def apply_embeddings(self, inputs): - raise NotImplementedError - - def apply_main_layers(self, inputs): - raise NotImplementedError - - def apply_final_layers(self, inputs): - raise NotImplementedError - - def apply_on_layer_begin(self, l_i, inputs): - - return inputs - - def apply_on_layer_end(self, l_i, inputs): - - return inputs - - def compute_attention_bias(self, inputs=None): - - return self.attention_bias - - def 
compute_position_bias(self, inputs=None): - - return self.position_bias - - def set_outputs(self, outputs): - - if not isinstance(outputs, list): - outputs = [outputs] - - outputs = outputs[:] - self.outputs = outputs - if len(outputs) > 1: - self.output = outputs - else: - self.output = outputs[0] - - -class LM_Mask(object): - - def compute_attention_bias(self, inputs=None): - seq_len = inputs[0].shape[1] - attention_bias = torch.tril( - torch.ones(seq_len, seq_len, dtype=torch.long, device=inputs[0].device), - diagonal=0, - ) - self.attention_bias = attention_bias.unsqueeze(0).unsqueeze(1) - return self.attention_bias - - -def extend_with_language_model(InputModel): - - class LanguageModel(LM_Mask, InputModel): - - def __init__(self, *args, **kwargs): - kwargs["with_mlm"] = kwargs.get("with_mlm") or True - super(LanguageModel, self).__init__(*args, **kwargs) - - return LanguageModel - - -class UniLM_Mask(object): - def compute_attention_bias(self, inputs=None): - segment_ids = inputs[1] - attention_bias = torch.cumsum(segment_ids, dim=1) - attention_bias = (attention_bias.unsqueeze(1)) <= (attention_bias.unsqueeze(2)) - self.attention_bias = attention_bias.unsqueeze(1).long() - - return self.attention_bias - - -def extend_with_unified_language_model(InputModel): - - class UnifiedLanguageModel(UniLM_Mask, InputModel): - - def __init__(self, *args, **kwargs): - kwargs["with_mlm"] = kwargs.get("with_mlm") or True - super(UnifiedLanguageModel, self).__init__(*args, **kwargs) - - return UnifiedLanguageModel - - -class BERT(BERT_BASE): - def __init__( - self, - max_position, - segment_vocab_size=2, - with_pool=False, - with_nsp=False, - with_mlm=False, - custom_position_ids=False, - custom_attention_mask=False, - shared_segment_embeddings=False, - layer_norm_cond=None, - layer_add_embs=None, - is_dropout=False, - token_pad_ids=0, - **kwargs, - ): - super(BERT, self).__init__(**kwargs) - self.max_position = max_position - self.segment_vocab_size = segment_vocab_size - self.with_pool = with_pool - self.with_nsp = with_nsp - self.with_mlm = with_mlm - self.custom_position_ids = custom_position_ids - self.custom_attention_mask = custom_attention_mask - self.shared_segment_embeddings = shared_segment_embeddings - self.is_dropout = is_dropout - self.token_pad_ids = token_pad_ids - if self.with_nsp and not self.with_pool: - self.with_pool = True - self.layer_norm_conds = layer_norm_cond - self.layer_add_embs = layer_add_embs - self.conditional_size = ( - layer_norm_cond.weight.size(1) if layer_norm_cond is not None else None - ) - self.embeddings = BertEmbeddings( - self.vocab_size, - self.embedding_size, - self.hidden_size, - self.max_position, - self.segment_vocab_size, - self.shared_segment_embeddings, - self.dropout_rate, - self.conditional_size, - **get_kw(BertEmbeddings, kwargs), - ) - kwargs["max_position"] = self.max_position - layer = BertLayer( - self.hidden_size, - self.num_attention_heads, - self.dropout_rate, - self.attention_probs_dropout_prob, - self.intermediate_size, - self.hidden_act, - is_dropout=self.is_dropout, - conditional_size=self.conditional_size, - **get_kw(BertLayer, kwargs), - ) - self.encoderLayer = nn.ModuleList( - [ - copy.deepcopy(layer) - if layer_id in self.keep_hidden_layers - else Identity() - for layer_id in range(self.num_hidden_layers) - ] - ) - if self.with_pool: - - self.pooler = nn.Linear(self.hidden_size, self.hidden_size) - self.pooler_activation = ( - nn.Tanh() if self.with_pool is True else get_activation(self.with_pool) - ) - if self.with_nsp: - - 
self.nsp = nn.Linear(self.hidden_size, 2) - else: - self.pooler = None - self.pooler_activation = None - if self.with_mlm: - self.mlmDense = nn.Linear(self.hidden_size, self.hidden_size) - self.transform_act_fn = get_activation(self.hidden_act) - self.mlmLayerNorm = LayerNorm( - self.hidden_size, eps=1e-12, conditional_size=self.conditional_size - ) - self.mlmDecoder = nn.Linear(self.hidden_size, self.vocab_size, bias=False) - if kwargs.get("tie_emb_prj_weight") is True: - self.mlmDecoder.weight = self.embeddings.word_embeddings.weight - self.mlmBias = nn.Parameter(torch.zeros(self.vocab_size)) - self.mlmDecoder.bias = self.mlmBias - - - def apply_embeddings(self, inputs): - token_ids = inputs[0] - index_ = 1 - if self.segment_vocab_size > 0: - segment_ids = inputs[index_] - index_ += 1 - else: - segment_ids = None - - if self.custom_position_ids: - position_ids = inputs[index_] - index_ += 1 - else: - position_ids = None - - if self.custom_attention_mask: - attention_mask = inputs[index_].long().unsqueeze(1).unsqueeze(2) - index_ += 1 - elif (not token_ids.requires_grad) and ( - token_ids.dtype in {torch.long, torch.int} - ): - attention_mask = ( - (token_ids != self.token_pad_ids).long().unsqueeze(1).unsqueeze(2) - ) - if self.token_pad_ids < 0: - token_ids = token_ids * attention_mask[:, 0, 0, :] - else: - attention_mask = self.attention_mask_cache - self.attention_mask_cache = attention_mask - - self.compute_attention_bias([token_ids, segment_ids]) - if self.attention_bias is not None: - attention_mask = attention_mask * self.attention_bias - - try: - attention_mask = attention_mask.to( - dtype=next(self.parameters()).dtype - ) - except StopIteration: - attention_mask = attention_mask.to(dtype=torch.float32) - - if self.layer_norm_conds is None: - conditional_emb = None - else: - conditional_emb = self.layer_norm_conds(inputs[index_]) - index_ += 1 - - - if isinstance(self.layer_add_embs, nn.Module): - additional_embs = [self.layer_add_embs(inputs[index_])] - index_ += 1 - elif isinstance(self.layer_add_embs, (tuple, list)): - additional_embs = [] - for layer in self.layer_add_embs: - assert isinstance( - layer, nn.Module - ), "Layer_add_embs element should be nn.Module" - additional_embs.append(layer(inputs[index_])) - index_ += 1 - else: - additional_embs = None - - - hidden_states = self.embeddings( - token_ids, segment_ids, conditional_emb, additional_embs - ) - return [hidden_states, attention_mask, conditional_emb] + inputs[index_:] - - def apply_main_layers(self, inputs): - hidden_states, attention_mask, conditional_emb = inputs[:3] - if len(inputs[3:]) >= 2: - encoder_hidden_state, encoder_attention_mask = inputs[3], inputs[4] - else: - encoder_hidden_state, encoder_attention_mask = None, None - - encoded_layers = [hidden_states] - layer_inputs = [ - hidden_states, - attention_mask, - conditional_emb, - encoder_hidden_state, - encoder_attention_mask, - ] - for l_i, layer_module in enumerate(self.encoderLayer): - layer_inputs = self.apply_on_layer_begin(l_i, layer_inputs) - hidden_states = layer_module(*layer_inputs) - layer_inputs[0] = hidden_states - layer_inputs = self.apply_on_layer_end(l_i, layer_inputs) - - if self.output_all_encoded_layers: - encoded_layers.append(hidden_states) - if not self.output_all_encoded_layers: - encoded_layers.append(hidden_states) - return [encoded_layers, conditional_emb] - - def apply_final_layers(self, inputs): - encoded_layers, conditional_emb = inputs - sequence_output = encoded_layers[-1] - - if not self.output_all_encoded_layers: - 
encoded_layers = encoded_layers[-1] - - - if self.with_pool: - pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0])) - else: - pooled_output = None - - if self.with_pool and self.with_nsp: - nsp_scores = self.nsp(pooled_output) - else: - nsp_scores = None - - if self.with_mlm: - mlm_hidden_state = self.mlmDense(sequence_output) - mlm_hidden_state = self.transform_act_fn(mlm_hidden_state) - mlm_hidden_state = self.mlmLayerNorm((mlm_hidden_state, conditional_emb)) - mlm_scores = self.mlmDecoder(mlm_hidden_state) - mlm_activation = get_activation( - "linear" if self.with_mlm is True else self.with_mlm - ) - mlm_scores = mlm_activation(mlm_scores) - else: - mlm_scores = None - - outputs = [ - value - for value in [encoded_layers, pooled_output, mlm_scores, nsp_scores] - if value is not None - ] - return outputs if len(outputs) > 1 else outputs[0] - - def load_variable(self, state_dict, name, prefix="bert"): - variable = state_dict[name] - if name in { - f"{prefix}.embeddings.word_embeddings.weight", - "cls.predictions.bias", - "cls.predictions.decoder.weight", - "cls.predictions.decoder.bias", - }: - return self.load_embeddings(variable) - elif name == f"{prefix}.embeddings.position_embeddings.weight": - return self.load_pos_embeddings(variable) - elif name == "cls.seq_relationship.weight": - return variable.T - else: - return variable - - def variable_mapping(self, prefix="bert"): - mapping = { - "embeddings.word_embeddings.weight": f"{prefix}.embeddings.word_embeddings.weight", - "embeddings.position_embeddings.weight": f"{prefix}.embeddings.position_embeddings.weight", - "embeddings.segment_embeddings.weight": f"{prefix}.embeddings.token_type_embeddings.weight", - "embeddings.layerNorm.weight": f"{prefix}.embeddings.LayerNorm.weight", - "embeddings.layerNorm.bias": f"{prefix}.embeddings.LayerNorm.bias", - "pooler.weight": f"{prefix}.pooler.dense.weight", - "pooler.bias": f"{prefix}.pooler.dense.bias", - "nsp.weight": "cls.seq_relationship.weight", - "nsp.bias": "cls.seq_relationship.bias", - "mlmDense.weight": "cls.predictions.transform.dense.weight", - "mlmDense.bias": "cls.predictions.transform.dense.bias", - "mlmLayerNorm.weight": "cls.predictions.transform.LayerNorm.weight", - "mlmLayerNorm.bias": "cls.predictions.transform.LayerNorm.bias", - "mlmBias": "cls.predictions.bias", - "mlmDecoder.weight": "cls.predictions.decoder.weight", - "mlmDecoder.bias": "cls.predictions.decoder.bias", - } - for i in range(self.num_hidden_layers): - prefix_i = f"{prefix}.encoder.layer.%d." 
% i - mapping.update( - { - f"encoderLayer.{i}.multiHeadAttention.q.weight": prefix_i - + "attention.self.query.weight", - f"encoderLayer.{i}.multiHeadAttention.q.bias": prefix_i - + "attention.self.query.bias", - f"encoderLayer.{i}.multiHeadAttention.k.weight": prefix_i - + "attention.self.key.weight", - f"encoderLayer.{i}.multiHeadAttention.k.bias": prefix_i - + "attention.self.key.bias", - f"encoderLayer.{i}.multiHeadAttention.v.weight": prefix_i - + "attention.self.value.weight", - f"encoderLayer.{i}.multiHeadAttention.v.bias": prefix_i - + "attention.self.value.bias", - f"encoderLayer.{i}.multiHeadAttention.o.weight": prefix_i - + "attention.output.dense.weight", - f"encoderLayer.{i}.multiHeadAttention.o.bias": prefix_i - + "attention.output.dense.bias", - f"encoderLayer.{i}.layerNorm1.weight": prefix_i - + "attention.output.LayerNorm.weight", - f"encoderLayer.{i}.layerNorm1.bias": prefix_i - + "attention.output.LayerNorm.bias", - f"encoderLayer.{i}.feedForward.intermediateDense.weight": prefix_i - + "intermediate.dense.weight", - f"encoderLayer.{i}.feedForward.intermediateDense.bias": prefix_i - + "intermediate.dense.bias", - f"encoderLayer.{i}.feedForward.outputDense.weight": prefix_i - + "output.dense.weight", - f"encoderLayer.{i}.feedForward.outputDense.bias": prefix_i - + "output.dense.bias", - f"encoderLayer.{i}.layerNorm2.weight": prefix_i - + "output.LayerNorm.weight", - f"encoderLayer.{i}.layerNorm2.bias": prefix_i - + "output.LayerNorm.bias", - } - ) - - return mapping - - -class ALBERT(BERT): - def __init__(self, *args, **kwargs): - super(ALBERT, self).__init__(*args, **kwargs) - self.encoderLayer = nn.ModuleList([self.encoderLayer[0]]) - - def apply_main_layers(self, inputs): - hidden_states, attention_mask, conditional_emb = inputs[:3] - if len(inputs[3:]) >= 2: - encoder_hidden_state, encoder_attention_mask = inputs[3], inputs[4] - else: - encoder_hidden_state, encoder_attention_mask = None, None - - encoded_layers = [hidden_states] - layer_inputs = [ - hidden_states, - attention_mask, - conditional_emb, - encoder_hidden_state, - encoder_attention_mask, - ] - for l_i in range(self.num_hidden_layers): - layer_inputs = self.apply_on_layer_begin(l_i, layer_inputs) - hidden_states = self.encoderLayer[0](*layer_inputs) - layer_inputs[0] = hidden_states - layer_inputs = self.apply_on_layer_end(l_i, layer_inputs) - - if self.output_all_encoded_layers: - encoded_layers.append(hidden_states) - if not self.output_all_encoded_layers: - encoded_layers.append(hidden_states) - return [encoded_layers, conditional_emb] - - def variable_mapping(self, prefix="albert"): - mapping = { - "embeddings.word_embeddings.weight": f"{prefix}.embeddings.word_embeddings.weight", - "embeddings.position_embeddings.weight": f"{prefix}.embeddings.position_embeddings.weight", - "embeddings.segment_embeddings.weight": f"{prefix}.embeddings.token_type_embeddings.weight", - "embeddings.layerNorm.weight": f"{prefix}.embeddings.LayerNorm.weight", - "embeddings.layerNorm.bias": f"{prefix}.embeddings.LayerNorm.bias", - "embeddings.embedding_hidden_mapping_in.weight": f"{prefix}.encoder.embedding_hidden_mapping_in.weight", - "embeddings.embedding_hidden_mapping_in.bias": f"{prefix}.encoder.embedding_hidden_mapping_in.bias", - "pooler.weight": f"{prefix}.pooler.weight", - "pooler.bias": f"{prefix}.pooler.bias", - "nsp.weight": "sop_classifier.classifier.weight", - "nsp.bias": "sop_classifier.classifier.bias", - "mlmDense.weight": "predictions.dense.weight", - "mlmDense.bias": "predictions.dense.bias", - 
"mlmLayerNorm.weight": "predictions.LayerNorm.weight", - "mlmLayerNorm.bias": "predictions.LayerNorm.bias", - "mlmBias": "predictions.bias", - "mlmDecoder.weight": "predictions.decoder.weight", - "mlmDecoder.bias": "predictions.decoder.bias", - } - i = 0 - prefix_i = f"{prefix}.encoder.albert_layer_groups.{i}.albert_layers.{i}." - mapping.update( - { - f"encoderLayer.{i}.multiHeadAttention.q.weight": prefix_i - + "attention.query.weight", - f"encoderLayer.{i}.multiHeadAttention.q.bias": prefix_i - + "attention.query.bias", - f"encoderLayer.{i}.multiHeadAttention.k.weight": prefix_i - + "attention.key.weight", - f"encoderLayer.{i}.multiHeadAttention.k.bias": prefix_i - + "attention.key.bias", - f"encoderLayer.{i}.multiHeadAttention.v.weight": prefix_i - + "attention.value.weight", - f"encoderLayer.{i}.multiHeadAttention.v.bias": prefix_i - + "attention.value.bias", - f"encoderLayer.{i}.multiHeadAttention.o.weight": prefix_i - + "attention.dense.weight", - f"encoderLayer.{i}.multiHeadAttention.o.bias": prefix_i - + "attention.dense.bias", - f"encoderLayer.{i}.layerNorm1.weight": prefix_i - + "attention.LayerNorm.weight", - f"encoderLayer.{i}.layerNorm1.bias": prefix_i - + "attention.LayerNorm.bias", - f"encoderLayer.{i}.feedForward.intermediateDense.weight": prefix_i - + "ffn.weight", - f"encoderLayer.{i}.feedForward.intermediateDense.bias": prefix_i - + "ffn.bias", - f"encoderLayer.{i}.feedForward.outputDense.weight": prefix_i - + "ffn_output.weight", - f"encoderLayer.{i}.feedForward.outputDense.bias": prefix_i - + "ffn_output.bias", - f"encoderLayer.{i}.layerNorm2.weight": prefix_i - + "full_layer_layer_norm.weight", - f"encoderLayer.{i}.layerNorm2.bias": prefix_i - + "full_layer_layer_norm.bias", - } - ) - - return mapping - - def load_variable(self, state_dict, name): - - variable = state_dict[name] - if name in { - "albert.embeddings.word_embeddings.weight", - "predictions.bias", - "predictions.decoder.weight", - "predictions.decoder.bias", - }: - return self.load_embeddings(variable) - elif name == "albert.embeddings.position_embeddings.weight": - return self.load_pos_embeddings(variable) - elif name == "sop_classifier.classifier.weight": - return variable.T - else: - return variable - - -class ALBERT_Unshared(ALBERT): - def __init__(self, *args, **kwargs): - super(ALBERT_Unshared).__init__(*args, **kwargs) - self.encoderLayer = nn.ModuleList( - [copy.deepcopy(self.encoderLayer[0]) for _ in range(self.num_hidden_layers)] - ) - - def apply_main_layers(self, inputs): - - hidden_states, attention_mask, conditional_emb = inputs - if len(inputs[3:]) >= 2: - encoder_hidden_state, encoder_attention_mask = inputs[3], inputs[4] - else: - encoder_hidden_state, encoder_attention_mask = None, None - - encoded_layers = [hidden_states] # 添加embedding的输出 - layer_inputs = [ - hidden_states, - attention_mask, - conditional_emb, - encoder_hidden_state, - encoder_attention_mask, - ] - for i in range(self.num_hidden_layers): - layer_inputs = self.apply_on_layer_begin(i, layer_inputs) - hidden_states = self.encoderLayer[i](*layer_inputs) - layer_inputs[0] = hidden_states - layer_inputs = self.apply_on_layer_end(i, layer_inputs) - - if self.output_all_encoded_layers: - encoded_layers.append(hidden_states) - if not self.output_all_encoded_layers: - encoded_layers.append(hidden_states) - return [encoded_layers, conditional_emb] - - -class NEZHA(BERT): - def __init__(self, *args, **kwargs): - - kwargs.update( - { - "p_bias": "typical_relative", - "max_relative_position": kwargs.get("max_relative_position", 64), 
- } - ) - super(NEZHA, self).__init__(*args, **kwargs) - - -class RoFormer(BERT): - def __init__(self, *args, **kwargs): - kwargs.update({"p_bias": "rotary"}) - super(RoFormer, self).__init__(*args, **kwargs) - - def load_variable(self, state_dict, name, prefix="roformer"): - return super().load_variable(state_dict, name, prefix) - - def variable_mapping(self, prefix="roformer"): - mapping = super().variable_mapping(prefix) - del mapping["embeddings.position_embeddings.weight"] - return mapping - - -class RoFormerV2(RoFormer): - @delete_arguments("with_pool", "with_nsp") - def __init__(self, *args, **kwargs): - kwargs.update( - {"p_bias": "rotary", "weight": False, "bias": False, "norm_mode": "rmsnorm"} - ) - super(RoFormerV2, self).__init__(*args, **kwargs) - if self.with_mlm: - del self.mlmLayerNorm - del self.mlmBias - del self.mlmDense - self.mlmDecoder.register_parameter("bias", None) - - def variable_mapping(self, prefix="roformer"): - mapping = super().variable_mapping(prefix) - mapping_new = {} - for k, v in mapping.items(): - if (not re.search("bias|layernorm", k.lower())) and ( - not re.search("bias|layernorm", v.lower()) - ): - mapping_new[k] = v - return mapping_new - - def apply_final_layers(self, inputs): - encoded_layers, conditional_emb = inputs - sequence_output = encoded_layers[-1] - if not self.output_all_encoded_layers: - encoded_layers = encoded_layers[-1] - - if self.with_mlm: - mlm_scores = self.mlmDecoder(sequence_output) - else: - mlm_scores = None - - outputs = [value for value in [encoded_layers, mlm_scores] if value is not None] - return outputs if len(outputs) > 1 else outputs[0] - - -class GAU_alpha(RoFormerV2): - def __init__(self, *args, **kwargs): - kwargs.update( - { - "p_bias": "rotary", - "weight": False, - "bias": False, - "norm_mode": "rmsnorm", - "normalization": "softmax_plus", - } - ) - super().__init__(*args, **kwargs) - - layer = self.GAU_Layer(**kwargs) - self.encoderLayer = nn.ModuleList( - [ - copy.deepcopy(layer) - if layer_id in self.keep_hidden_layers - else Identity() - for layer_id in range(self.num_hidden_layers) - ] - ) - - def load_variable(self, state_dict, name, prefix=""): - variable = state_dict[name] - return ( - self.load_embeddings(variable) - if name in {"embeddings.word_embeddings.weight", "mlmDecoder.weight"} - else variable - ) - - def variable_mapping(self, prefix=""): - return {k: k for k, _ in self.named_parameters()} - - class GAU_Layer(nn.Module): - def __init__(self, *args, **kwargs): - super().__init__() - self.gau = GatedAttentionUnit(**kwargs) - self.dropout1 = nn.Dropout(kwargs.get("dropout_rate")) - self.layerNorm1 = LayerNorm(**kwargs) - - def forward( - self, - hidden_states, - attention_mask, - conditional_emb=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - ): - gau_hidden_states = self.gau(hidden_states, attention_mask) - hidden_states = hidden_states + self.dropout1(gau_hidden_states) - hidden_states = self.layerNorm1((hidden_states, conditional_emb)) - return hidden_states - - -class ELECTRA(BERT): - @insert_arguments(with_discriminator=False) - @delete_arguments("with_pool", "with_mlm", "with_nsp") - def __init__(self, max_position, **kwargs): - super(ELECTRA, self).__init__(max_position, **kwargs) - if self.with_discriminator: - self.dense = nn.Linear(self.hidden_size, self.hidden_size) - self.dense_act = get_activation(self.hidden_act) - self.dense_prediction = nn.Linear(self.hidden_size, 1) - self.dense_prediction_act = ( - get_activation("sigmoid") - if self.with_discriminator is 
True - else get_activation(self.with_discriminator) - ) - - def apply_final_layers(self, inputs): - hidden_states = super().apply_final_layers(inputs) - if self.with_discriminator: - logit = self.dense_act(self.dense(hidden_states)) - return [ - hidden_states, - self.dense_prediction_act(self.dense_prediction(logit)), - ] - else: - return hidden_states - - def load_variable(self, state_dict, name): - return super().load_variable(state_dict, name, prefix="electra") - - def variable_mapping(self): - mapping = super(ELECTRA, self).variable_mapping(prefix="electra") - mapping.update( - { - "dense.weight": "discriminator_predictions.dense.weight", - "dense.bias": "discriminator_predictions.dense.bias", - "dense_prediction.weight": "discriminator_predictions.dense_prediction.weight", - "dense_prediction.bias": "discriminator_predictions.dense_prediction.bias", - } - ) - for del_key in [ - "pooler.weight", - "pooler.bias", - "nsp.weight", - "nsp.bias", - "mlmDense.weight", - "mlmDense.bias", - "mlmLayerNorm.weight", - "mlmLayerNorm.bias", - "mlmBias", - "mlmDecoder.weight", - "mlmDecoder.bias", - ]: - del mapping[del_key] - - return mapping - - -class Encoder(BERT): - def __init__(self, *args, **kwargs): - kwargs["vocab_size"] = kwargs.get("src_vocab_size", kwargs["vocab_size"]) - super().__init__(*args, **kwargs) - self.encoder_attention_mask = None - - def forward(self, inputs): - # Embedding - outputs = self.apply_embeddings(inputs) - encoder_attention_mask = [outputs[1]] - # Main - outputs = self.apply_main_layers(outputs) - # Final - outputs = self.apply_final_layers(outputs) - return ( - [outputs] if isinstance(outputs, torch.Tensor) else outputs - ) + encoder_attention_mask - - -class Decoder(LM_Mask, BERT): - @delete_arguments("with_pool", "with_mlm", "with_nsp") - def __init__(self, *args, with_lm=True, tie_emb_prj_weight=True, **kwargs): - kwargs["vocab_size"] = kwargs.get("tgt_vocab_size", kwargs["vocab_size"]) - kwargs["is_decoder"] = True - super().__init__(*args, **kwargs) - self.decoderLayer = self.encoderLayer - del self.encoderLayer - self.with_lm = with_lm - - if self.with_lm: - self.final_dense = nn.Linear(self.hidden_size, self.vocab_size, bias=False) - if tie_emb_prj_weight: - self.final_dense.weight = self.embeddings.word_embeddings.weight - self.x_logit_scale = self.hidden_size**-0.5 - else: - self.x_logit_scale = 1.0 - - def apply_main_layers(self, inputs): - ( - hidden_states, - attention_mask, - conditional_emb, - encoder_hidden_state, - encoder_attention_mask, - ) = inputs[:5] - decoded_layers = [hidden_states] - layer_inputs = [ - hidden_states, - attention_mask, - conditional_emb, - encoder_hidden_state, - encoder_attention_mask, - ] - for i, layer_module in enumerate(self.decoderLayer): - layer_inputs = self.apply_on_layer_begin(i, layer_inputs) - hidden_states = layer_module(*layer_inputs) - layer_inputs[0] = hidden_states - layer_inputs = self.apply_on_layer_end(i, layer_inputs) - - if self.output_all_encoded_layers: - decoded_layers.append(hidden_states) - if not self.output_all_encoded_layers: - decoded_layers.append(hidden_states) - return [decoded_layers, conditional_emb] - - def apply_final_layers(self, inputs): - outputs = [] - hidden_states = super().apply_final_layers( - inputs - ) - outputs.append(hidden_states) - if self.with_lm: - logits = ( - self.final_dense(hidden_states) * self.x_logit_scale - ) - activation = get_activation( - "linear" if self.with_lm is True else self.with_lm - ) - logits = activation(logits) - outputs.append(logits) - return 
outputs - - def variable_mapping(self, prefix="bert"): - raw_mapping = super().variable_mapping(prefix) - mapping = {} - for k, v in raw_mapping.items(): - mapping[k.replace("encoderLayer", "decoderLayer")] = v - return mapping - - -class Transformer(BERT_BASE): - """encoder-decoder结构""" - - @delete_arguments("with_pool", "with_mlm", "with_nsp") - def __init__(self, *args, tie_emb_src_tgt_weight=False, **kwargs): - super(Transformer, self).__init__(*args, **kwargs) - - # encoder - self.encoder = Encoder(*args, **kwargs) - self.encoder.build(**kwargs) - - # decoder - self.decoder = Decoder(*args, **kwargs) - self.decoder.build(**kwargs) - - if tie_emb_src_tgt_weight: - assert ( - self.encoder.vocab_size == self.decoder.vocab_size - ), "To share word embedding, the vocab size of src/tgt shall be the same." - self.encoder.embeddings.word_embeddings.weight = ( - self.decoder.embeddings.word_embeddings.weight - ) - - def forward(self, inputs): - encoder_input, decoder_input = inputs[:2] - - # encoder - # encoder_emb = self.encoder.apply_embeddings(encoder_input) - # encode_outputs = self.encoder.apply_main_layers(encoder_emb) - # encoder_hidden_state = self.encoder.apply_final_layers(encode_outputs) - # encoder_attention_mask = encoder_emb[1] - encoder_hidden_state, encoder_attention_mask = self.encoder(encoder_input) - - # decoder - # decoder_emb = self.decoder.apply_embeddings(decoder_input) - # decoder_outputs = self.decoder.apply_main_layers([*decoder_emb, encoder_hidden_state, encoder_attention_mask]) - # decoder_outputs = self.decoder.apply_final_layers(decoder_outputs) # [hidden_states, logits] - decoder_outputs = self.decoder( - decoder_input + [encoder_hidden_state, encoder_attention_mask] - ) - return [ - encoder_hidden_state - ] + decoder_outputs - - -class BART(Transformer): - """encoder-decoder结构""" - - def __init__(self, *args, tie_emb_src_tgt_weight=True, **kwargs): - super(BART, self).__init__( - *args, tie_emb_src_tgt_weight=tie_emb_src_tgt_weight, **kwargs - ) - self.tie_emb_src_tgt_weight = tie_emb_src_tgt_weight - - def load_variable(self, state_dict, name, prefix=""): - variable = state_dict[name] - if name in { - "shared.weight", - "encoder.embed_tokens.weight", - "decoder.embed_tokens.weight", - }: - return self.load_embeddings(variable) - elif name in { - "encoder.embed_positions.weight", - "decoder.embed_positions.weight", - }: - return self.load_pos_embeddings(variable) - else: - return variable - - def variable_mapping(self, prefix=""): - mapping = { - "encoder.embeddings.word_embeddings.weight": "shared.weight" - if self.tie_emb_src_tgt_weight - else "encoder.embed_tokens.weight", - "encoder.embeddings.position_embeddings.weight": "encoder.embed_positions.weight", - "encoder.embeddings.layerNorm.weight": "encoder.layernorm_embedding.weight", - "encoder.embeddings.layerNorm.bias": "encoder.layernorm_embedding.bias", - "decoder.embeddings.word_embeddings.weight": "shared.weight" - if self.tie_emb_src_tgt_weight - else "decoder.embed_tokens.weight", - "decoder.embeddings.position_embeddings.weight": "decoder.embed_positions.weight", - "decoder.embeddings.layerNorm.weight": "decoder.layernorm_embedding.weight", - "decoder.embeddings.layerNorm.bias": "decoder.layernorm_embedding.bias", - } - for i in range(self.num_hidden_layers): - mapping.update( - { - f"encoder.encoderLayer.{i}.multiHeadAttention.q.weight": f"encoder.layers.{i}.self_attn.q_proj.weight", - f"encoder.encoderLayer.{i}.multiHeadAttention.q.bias": f"encoder.layers.{i}.self_attn.q_proj.bias", - 
f"encoder.encoderLayer.{i}.multiHeadAttention.k.weight": f"encoder.layers.{i}.self_attn.k_proj.weight", - f"encoder.encoderLayer.{i}.multiHeadAttention.k.bias": f"encoder.layers.{i}.self_attn.k_proj.bias", - f"encoder.encoderLayer.{i}.multiHeadAttention.v.weight": f"encoder.layers.{i}.self_attn.v_proj.weight", - f"encoder.encoderLayer.{i}.multiHeadAttention.v.bias": f"encoder.layers.{i}.self_attn.v_proj.bias", - f"encoder.encoderLayer.{i}.multiHeadAttention.o.weight": f"encoder.layers.{i}.self_attn.out_proj.weight", - f"encoder.encoderLayer.{i}.multiHeadAttention.o.bias": f"encoder.layers.{i}.self_attn.out_proj.bias", - f"encoder.encoderLayer.{i}.layerNorm1.weight": f"encoder.layers.{i}.self_attn_layer_norm.weight", - f"encoder.encoderLayer.{i}.layerNorm1.bias": f"encoder.layers.{i}.self_attn_layer_norm.bias", - f"encoder.encoderLayer.{i}.feedForward.intermediateDense.weight": f"encoder.layers.{i}.fc1.weight", - f"encoder.encoderLayer.{i}.feedForward.intermediateDense.bias": f"encoder.layers.{i}.fc1.bias", - f"encoder.encoderLayer.{i}.feedForward.outputDense.weight": f"encoder.layers.{i}.fc2.weight", - f"encoder.encoderLayer.{i}.feedForward.outputDense.bias": f"encoder.layers.{i}.fc2.bias", - f"encoder.encoderLayer.{i}.layerNorm2.weight": f"encoder.layers.{i}.final_layer_norm.weight", - f"encoder.encoderLayer.{i}.layerNorm2.bias": f"encoder.layers.{i}.final_layer_norm.bias", - f"decoder.decoderLayer.{i}.multiHeadAttention.q.weight": f"decoder.layers.{i}.self_attn.q_proj.weight", - f"decoder.decoderLayer.{i}.multiHeadAttention.q.bias": f"decoder.layers.{i}.self_attn.q_proj.bias", - f"decoder.decoderLayer.{i}.multiHeadAttention.k.weight": f"decoder.layers.{i}.self_attn.k_proj.weight", - f"decoder.decoderLayer.{i}.multiHeadAttention.k.bias": f"decoder.layers.{i}.self_attn.k_proj.bias", - f"decoder.decoderLayer.{i}.multiHeadAttention.v.weight": f"decoder.layers.{i}.self_attn.v_proj.weight", - f"decoder.decoderLayer.{i}.multiHeadAttention.v.bias": f"decoder.layers.{i}.self_attn.v_proj.bias", - f"decoder.decoderLayer.{i}.multiHeadAttention.o.weight": f"decoder.layers.{i}.self_attn.out_proj.weight", - f"decoder.decoderLayer.{i}.multiHeadAttention.o.bias": f"decoder.layers.{i}.self_attn.out_proj.bias", - f"decoder.decoderLayer.{i}.layerNorm1.weight": f"decoder.layers.{i}.self_attn_layer_norm.weight", - f"decoder.decoderLayer.{i}.layerNorm1.bias": f"decoder.layers.{i}.self_attn_layer_norm.bias", - f"decoder.decoderLayer.{i}.crossAttention.q.weight": f"decoder.layers.{i}.encoder_attn.q_proj.weight", - f"decoder.decoderLayer.{i}.crossAttention.q.bias": f"decoder.layers.{i}.encoder_attn.q_proj.bias", - f"decoder.decoderLayer.{i}.crossAttention.k.weight": f"decoder.layers.{i}.encoder_attn.k_proj.weight", - f"decoder.decoderLayer.{i}.crossAttention.k.bias": f"decoder.layers.{i}.encoder_attn.k_proj.bias", - f"decoder.decoderLayer.{i}.crossAttention.v.weight": f"decoder.layers.{i}.encoder_attn.v_proj.weight", - f"decoder.decoderLayer.{i}.crossAttention.v.bias": f"decoder.layers.{i}.encoder_attn.v_proj.bias", - f"decoder.decoderLayer.{i}.crossAttention.o.weight": f"decoder.layers.{i}.encoder_attn.out_proj.weight", - f"decoder.decoderLayer.{i}.crossAttention.o.bias": f"decoder.layers.{i}.encoder_attn.out_proj.bias", - f"decoder.decoderLayer.{i}.layerNorm3.weight": f"decoder.layers.{i}.encoder_attn_layer_norm.weight", - f"decoder.decoderLayer.{i}.layerNorm3.bias": f"decoder.layers.{i}.encoder_attn_layer_norm.bias", - f"decoder.decoderLayer.{i}.feedForward.intermediateDense.weight": 
f"decoder.layers.{i}.fc1.weight", - f"decoder.decoderLayer.{i}.feedForward.intermediateDense.bias": f"decoder.layers.{i}.fc1.bias", - f"decoder.decoderLayer.{i}.feedForward.outputDense.weight": f"decoder.layers.{i}.fc2.weight", - f"decoder.decoderLayer.{i}.feedForward.outputDense.bias": f"decoder.layers.{i}.fc2.bias", - f"decoder.decoderLayer.{i}.layerNorm2.weight": f"decoder.layers.{i}.final_layer_norm.weight", - f"decoder.decoderLayer.{i}.layerNorm2.bias": f"decoder.layers.{i}.final_layer_norm.bias", - } - ) - - return mapping - - -class T5_Encoder(Encoder): - @insert_arguments(version="t5.1.0") - def __init__(self, *args, **kwargs): - kwargs.update( - { - "p_bias": "t5_relative", - "relative_attention_num_buckets": kwargs.get( - "relative_attention_num_buckets" - ), - "version": self.version, - "bias": False, - "norm_mode": "rmsnorm", - } - ) - super().__init__(*args, **kwargs) - del self.embeddings.layerNorm - - layer = T5Layer( - self.hidden_size, - self.num_attention_heads, - self.dropout_rate, - self.attention_probs_dropout_prob, - self.intermediate_size, - self.hidden_act, - is_dropout=self.is_dropout, - conditional_size=self.conditional_size, - **get_kw(BertLayer, kwargs), - ) - self.encoderLayer = nn.ModuleList( - [copy.deepcopy(layer) for _ in range(self.num_hidden_layers)] - ) - - for i in range(1, self.num_hidden_layers): - self.encoderLayer[ - i - ].multiHeadAttention.relative_positions_encoding.weight = self.encoderLayer[ - 0 - ].multiHeadAttention.relative_positions_encoding.weight - self.final_layer_norm = LayerNorm( - self.hidden_size, - eps=1e-12, - conditional_size=self.conditional_size, - bias=False, - mode="rmsnorm", - ) - self.dropout = nn.Dropout(self.dropout_rate) - - def apply_final_layers(self, inputs): - hidden_states = super().apply_final_layers(inputs) - return self.dropout(self.final_layer_norm([hidden_states])) - - def load_variable(self, state_dict, name, prefix=""): - variable = state_dict[name] - if name in {"encoder.embed_tokens.weight", "shared.weight"}: - return self.load_embeddings(variable) - else: - return variable - - def variable_mapping(self, prefix=""): - mapping = { - f"{prefix}embeddings.word_embeddings.weight": "encoder.embed_tokens.weight", - f"{prefix}encoderLayer.0.multiHeadAttention.relative_positions_encoding.weight": "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight", - f"{prefix}final_layer_norm.weight": "encoder.final_layer_norm.weight", - } - for i in range(self.num_hidden_layers): - mapping.update( - { - f"{prefix}encoderLayer.{i}.multiHeadAttention.q.weight": f"encoder.block.{i}.layer.0.SelfAttention.q.weight", - f"{prefix}encoderLayer.{i}.multiHeadAttention.k.weight": f"encoder.block.{i}.layer.0.SelfAttention.k.weight", - f"{prefix}encoderLayer.{i}.multiHeadAttention.v.weight": f"encoder.block.{i}.layer.0.SelfAttention.v.weight", - f"{prefix}encoderLayer.{i}.multiHeadAttention.o.weight": f"encoder.block.{i}.layer.0.SelfAttention.o.weight", - f"{prefix}encoderLayer.{i}.layerNorm1.weight": f"encoder.block.{i}.layer.0.layer_norm.weight", - f"{prefix}encoderLayer.{i}.feedForward.outputDense.weight": f"encoder.block.{i}.layer.1.DenseReluDense.wo.weight", - f"{prefix}encoderLayer.{i}.layerNorm2.weight": f"encoder.block.{i}.layer.1.layer_norm.weight", - } - ) - - if self.version.endswith("t5.1.0"): - mapping.update( - { - f"{prefix}encoderLayer.{i}.feedForward.intermediateDense.weight": f"encoder.block.{i}.layer.1.DenseReluDense.wi.weight" - } - ) - elif self.version.endswith("t5.1.1"): - mapping.update( - { - 
f"{prefix}encoderLayer.{i}.feedForward.intermediateDense.weight": f"encoder.block.{i}.layer.1.DenseReluDense.wi_0.weight", - f"{prefix}encoderLayer.{i}.feedForward.intermediateDense1.weight": f"encoder.block.{i}.layer.1.DenseReluDense.wi_1.weight", - } - ) - return mapping - - -class T5_Decoder(Decoder): - @insert_arguments(version="t5.1.0") - def __init__(self, *args, **kwargs): - kwargs.update( - { - "p_bias": "t5_relative", - "relative_attention_num_buckets": kwargs.get( - "relative_attention_num_buckets" - ), - "version": self.version, - "bias": False, - "norm_mode": "rmsnorm", - } - ) - super().__init__(*args, **kwargs) - del self.embeddings.layerNorm - - layer = T5Layer( - self.hidden_size, - self.num_attention_heads, - self.dropout_rate, - self.attention_probs_dropout_prob, - self.intermediate_size, - self.hidden_act, - is_dropout=self.is_dropout, - conditional_size=self.conditional_size, - is_decoder=True, - **get_kw(BertLayer, kwargs), - ) - self.decoderLayer = nn.ModuleList( - [copy.deepcopy(layer) for _ in range(self.num_hidden_layers)] - ) - - for i in range(1, self.num_hidden_layers): - self.decoderLayer[ - i - ].multiHeadAttention.relative_positions_encoding.weight = self.decoderLayer[ - 0 - ].multiHeadAttention.relative_positions_encoding.weight - self.final_layer_norm = LayerNorm( - self.hidden_size, - eps=1e-12, - conditional_size=self.conditional_size, - bias=False, - mode="rmsnorm", - ) - self.dropout = nn.Dropout(self.dropout_rate) - - def apply_final_layers(self, inputs): - inputs[0][1] = self.dropout( - self.final_layer_norm([inputs[0][1]]) - ) - return super().apply_final_layers(inputs) - - def load_variable(self, state_dict, name, prefix=""): - variable = state_dict[name] - if name in {f"decoder.embed_tokens.weight", "lm_head.weight", "shared.weight"}: - return self.load_embeddings(variable) - else: - return variable - - def variable_mapping(self, prefix=""): - mapping = { - f"{prefix}embeddings.word_embeddings.weight": "decoder.embed_tokens.weight", - f"{prefix}decoderLayer.0.multiHeadAttention.relative_positions_encoding.weight": "decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight", - f"{prefix}final_layer_norm.weight": "decoder.final_layer_norm.weight", - f"{prefix}final_dense.weight": "lm_head.weight", - } - - for i in range(self.num_hidden_layers): - mapping.update( - { - f"{prefix}decoderLayer.{i}.multiHeadAttention.q.weight": f"decoder.block.{i}.layer.0.SelfAttention.q.weight", - f"{prefix}decoderLayer.{i}.multiHeadAttention.k.weight": f"decoder.block.{i}.layer.0.SelfAttention.k.weight", - f"{prefix}decoderLayer.{i}.multiHeadAttention.v.weight": f"decoder.block.{i}.layer.0.SelfAttention.v.weight", - f"{prefix}decoderLayer.{i}.multiHeadAttention.o.weight": f"decoder.block.{i}.layer.0.SelfAttention.o.weight", - f"{prefix}decoderLayer.{i}.layerNorm1.weight": f"decoder.block.{i}.layer.0.layer_norm.weight", - f"{prefix}decoderLayer.{i}.crossAttention.q.weight": f"decoder.block.{i}.layer.1.EncDecAttention.q.weight", - f"{prefix}decoderLayer.{i}.crossAttention.k.weight": f"decoder.block.{i}.layer.1.EncDecAttention.k.weight", - f"{prefix}decoderLayer.{i}.crossAttention.v.weight": f"decoder.block.{i}.layer.1.EncDecAttention.v.weight", - f"{prefix}decoderLayer.{i}.crossAttention.o.weight": f"decoder.block.{i}.layer.1.EncDecAttention.o.weight", - f"{prefix}decoderLayer.{i}.layerNorm3.weight": f"decoder.block.{i}.layer.1.layer_norm.weight", - f"{prefix}decoderLayer.{i}.feedForward.outputDense.weight": 
f"decoder.block.{i}.layer.2.DenseReluDense.wo.weight", - f"{prefix}decoderLayer.{i}.layerNorm2.weight": f"decoder.block.{i}.layer.2.layer_norm.weight", - } - ) - - if self.version.endswith("t5.1.0"): - mapping.update( - { - f"{prefix}decoderLayer.{i}.feedForward.intermediateDense.weight": f"decoder.block.{i}.layer.2.DenseReluDense.wi.weight" - } - ) - elif self.version.endswith("t5.1.1"): - mapping.update( - { - f"{prefix}decoderLayer.{i}.feedForward.intermediateDense.weight": f"decoder.block.{i}.layer.2.DenseReluDense.wi_0.weight", - f"{prefix}decoderLayer.{i}.feedForward.intermediateDense1.weight": f"decoder.block.{i}.layer.2.DenseReluDense.wi_1.weight", - } - ) - return mapping - - -class T5(Transformer): - @delete_arguments("with_pool", "with_mlm", "with_nsp") - def __init__(self, *args, tie_emb_src_tgt_weight=True, **kwargs): - super(T5, self).__init__(*args, **kwargs) - self.tie_emb_src_tgt_weight = tie_emb_src_tgt_weight - - # encoder - self.encoder = T5_Encoder(*args, **kwargs) - self.encoder.build(**kwargs) - - # decoder - self.decoder = T5_Decoder(*args, **kwargs) - self.decoder.build(**kwargs) - - def load_variable(self, state_dict, name, prefix=""): - variable = state_dict[name] - if name in { - "shared.weight", - "encoder.embed_tokens.weight", - "decoder.embed_tokens.weight", - "lm_head.weight", - }: - return self.load_embeddings(variable) - else: - return variable - - def variable_mapping(self, prefix=""): - mapping = self.encoder.variable_mapping(prefix="encoder.") - mapping.update(self.decoder.variable_mapping(prefix="decoder.")) - if self.tie_emb_src_tgt_weight: - mapping.update( - { - "encoder.embeddings.word_embeddings.weight": "shared.weight", - "decoder.embeddings.word_embeddings.weight": "shared.weight", - } - ) - return mapping - - -class GPT(LM_Mask, BERT): - @insert_arguments(final_activation="softmax") - @delete_arguments("with_pool", "with_mlm", "with_nsp") - def __init__(self, max_position, **kwargs): - super(GPT, self).__init__(max_position, **kwargs) - del self.embeddings.layerNorm - self.dense = nn.Linear(self.hidden_size, self.vocab_size, bias=False) - self.dense.weight = self.embeddings.word_embeddings.weight - self.final_activation = get_activation(self.final_activation) - - def apply_final_layers(self, inputs): - hidden_state = super().apply_final_layers(inputs) - logit = self.dense(hidden_state) - return self.final_activation(logit) - - def load_variable(self, state_dict, name): - return super(GPT, self).load_variable(state_dict, name, prefix="gpt") - - def variable_mapping(self): - mapping = super(GPT, self).variable_mapping(prefix="gpt") - return mapping - - -class GPT2(LM_Mask, BERT): - - @insert_arguments(final_activation="softmax") - @delete_arguments("with_pool", "with_mlm", "with_nsp") - def __init__(self, max_position, **kwargs): - super(GPT2, self).__init__(max_position, **kwargs) - del self.embeddings.layerNorm - layer = self.Gpt2Layer( - self.hidden_size, - self.num_attention_heads, - self.dropout_rate, - self.attention_probs_dropout_prob, - self.intermediate_size, - self.hidden_act, - is_dropout=self.is_dropout, - conditional_size=self.conditional_size, - ) - self.encoderLayer = nn.ModuleList( - [ - copy.deepcopy(layer) - if layer_id in self.keep_hidden_layers - else Identity() - for layer_id in range(self.num_hidden_layers) - ] - ) - self.LayerNormFinal = LayerNorm( - self.hidden_size, eps=1e-12, conditional_size=self.conditional_size - ) - self.dense = nn.Linear(self.hidden_size, self.vocab_size, bias=False) - self.dense.weight = 
self.embeddings.word_embeddings.weight - self.final_activation = get_activation(self.final_activation) - - def apply_final_layers(self, inputs): - hidden_state = super().apply_final_layers(inputs) - logit = self.dense(self.LayerNormFinal([hidden_state])) - return self.final_activation(logit) - - def load_variable(self, state_dict, name): - return super(GPT2, self).load_variable(state_dict, name, prefix="gpt2") - - def variable_mapping(self): - mapping = super(GPT2, self).variable_mapping(prefix="gpt2") - mapping.update( - { - "LayerNormFinal.weight": "gpt2.LayerNormFinal.weight", - "LayerNormFinal.bias": "gpt2.LayerNormFinal.bias", - } - ) - return mapping - - class Gpt2Layer(BertLayer): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def forward( - self, - hidden_states, - attention_mask, - conditional_emb=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - ): - x = self.layerNorm1((hidden_states, conditional_emb)) - self_attn_output = self.multiHeadAttention(x, attention_mask) - hidden_states = hidden_states + self.dropout1(self_attn_output) - x = self.layerNorm2((hidden_states, conditional_emb)) - ffn_output = self.feedForward(x) - hidden_states = hidden_states + self.dropout2(ffn_output) - return hidden_states - - -class GPT2_ML(LM_Mask, BERT): - @insert_arguments(final_activation="softmax") - @delete_arguments("with_pool", "with_mlm", "with_nsp") - def __init__(self, max_position, **kwargs): - super().__init__(max_position, **kwargs) - layer = self.Gpt2MlLayer( - self.hidden_size, - self.num_attention_heads, - self.dropout_rate, - self.attention_probs_dropout_prob, - self.intermediate_size, - self.hidden_act, - is_dropout=self.is_dropout, - conditional_size=self.conditional_size, - ) - self.encoderLayer = nn.ModuleList( - [ - copy.deepcopy(layer) - if layer_id in self.keep_hidden_layers - else Identity() - for layer_id in range(self.num_hidden_layers) - ] - ) - self.dense = nn.Linear(self.hidden_size, self.vocab_size, bias=False) - self.dense.weight = self.embeddings.word_embeddings.weight - self.final_activation = get_activation(self.final_activation) - - def apply_final_layers(self, inputs): - hidden_state = super().apply_final_layers(inputs) - logit = self.dense(hidden_state) - return self.final_activation(logit) - - def load_variable(self, state_dict, name): - return super(GPT2_ML, self).load_variable(state_dict, name, prefix="gpt2_ml") - - def variable_mapping(self): - mapping = super(GPT2_ML, self).variable_mapping(prefix="gpt2_ml") - return mapping - - class Gpt2MlLayer(BertLayer): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def forward( - self, - hidden_states, - attention_mask, - conditional_emb=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - ): - self_attn_output = self.multiHeadAttention(hidden_states, attention_mask) - hidden_states = hidden_states + self.dropout1(self_attn_output) - x = self.layerNorm1((hidden_states, conditional_emb)) - ffn_output = self.feedForward(x) - hidden_states = hidden_states + self.dropout2(ffn_output) - hidden_states = self.layerNorm2((hidden_states, conditional_emb)) - return hidden_states - - -class Transformer_XL(BERT): - @delete_arguments("with_pool", "with_nsp", "with_mlm") - @insert_arguments(with_lm=False) - def __init__(self, *args, mem_len=0, same_length=False, clamp_len=-1, **kwargs): - # p_bias来控制embedding阶段无pos_embedding - kwargs.update({"p_bias": "other_relative"}) - super().__init__(*args, **kwargs) - self.mem_len, 
self.same_length, self.clamp_len = mem_len, same_length, clamp_len - self.attn_type = kwargs.get("attn_type", 0) - - # embedding - if kwargs.get("adaptive_embedding"): - cutoffs, div_val, sample_softmax = ( - kwargs.get("cutoffs", []), - kwargs.get("div_val", 1), - kwargs.get("sample_softmax", False), - ) - self.embeddings = AdaptiveEmbedding( - self.vocab_size, - self.embedding_size, - self.hidden_size, - cutoffs, - div_val, - sample_softmax, - **get_kw(AdaptiveEmbedding, kwargs), - ) - else: - self.embeddings = nn.Embedding(self.vocab_size, self.embedding_size) - self.pos_embeddings = XlnetPositionsEncoding(self.embedding_size) - self.dropout = nn.Dropout(self.dropout_rate) - - if not kwargs.get("untie_r"): - self.r_w_bias = nn.Parameter( - torch.FloatTensor(self.num_attention_heads, self.attention_head_size) - ) - self.r_r_bias = nn.Parameter( - torch.FloatTensor(self.num_attention_heads, self.attention_head_size) - ) - if self.segment_vocab_size > 0: - self.r_s_bias = nn.Parameter( - torch.FloatTensor( - self.num_attention_heads, self.attention_head_size - ) - ) - else: - self.r_w_bias, self.r_r_bias = None, None - self.r_s_bias = None - - # transformer block - layer = XlnetLayer( - self.hidden_size, - self.num_attention_heads, - self.dropout_rate, - self.attention_probs_dropout_prob, - self.intermediate_size, - self.hidden_act, - is_dropout=self.is_dropout, - conditional_size=self.conditional_size, - r_w_bias=self.r_w_bias, - r_r_bias=self.r_r_bias, - r_s_bias=None, - **get_kw(BertLayer, kwargs), - ) - self.encoderLayer = nn.ModuleList( - [ - copy.deepcopy(layer) - if layer_id in self.keep_hidden_layers - else Identity() - for layer_id in range(self.num_hidden_layers) - ] - ) - - # 映射 - if self.with_lm: - self.dense = nn.Linear(self.hidden_size, self.vocab_size, bias=True) - - def init_mems(self, bsz): - if isinstance(self.mem_len, (int, float)) and (self.mem_len > 0): - mems = [] - param = next(self.parameters()) - for _ in range(self.num_hidden_layers + 1): - empty = torch.zeros( - bsz, - self.mem_len, - self.hidden_size, - dtype=param.dtype, - device=param.device, - ) - mems.append(empty) - - return mems - else: - return None - - def _update_mems(self, hids, mlen, qlen): - # does not deal with None - if self.mems is None: - return None - # mems is not None - assert len(hids) == len(self.mems), "len(hids) != len(mems)" - # There are `mlen + qlen` steps that can be cached into mems - with torch.no_grad(): - new_mems = [] - end_idx = mlen + max(0, qlen) - beg_idx = max(0, end_idx - self.mem_len) - for i in range(len(hids)): - cat = torch.cat([self.mems[i], hids[i]], dim=1) - new_mems.append(cat[:, beg_idx:end_idx].detach()) - self.mems = new_mems - - def relative_positional_encoding(self, qlen, klen, device): - pos_seq = torch.arange(klen - 1, -1, -1.0, device=device, dtype=torch.long) - if self.clamp_len > 0: - pos_seq.clamp_(max=self.clamp_len) - pos_emb = self.dropout(self.pos_embeddings(pos_seq)) - return pos_emb - - def create_mask(self, word_emb, qlen, klen, mlen): - - if self.same_length: - all_ones = word_emb.new_ones(qlen, klen) - mask_len = klen - self.mem_len - mask_shift_len = qlen - mask_len if mask_len > 0 else qlen - attention_mask = ( - 1 - - ( - torch.triu(all_ones, 1 + mlen) - + torch.tril(all_ones, -mask_shift_len) - ).byte() - ) - else: - attention_mask = torch.tril( - word_emb.new_ones(qlen, klen), diagonal=mlen - ).byte() - attention_mask = attention_mask[None, None, :, :] - return attention_mask - - def apply_embeddings(self, inputs): - - self.mems = 
self.init_mems(inputs[0].size(0)) - - - word_emb = self.dropout(self.embeddings(inputs[0])) - index_ = 1 - btz, qlen = inputs[0].shape[:2] - mlen = self.mems[0].size(1) if self.mems is not None else 0 - klen = mlen + qlen - - pos_emb = self.relative_positional_encoding(qlen, klen, word_emb.device) - - if self.segment_vocab_size > 0: - segment_ids = inputs[index_] - if mlen > 0: - mem_pad = torch.zeros( - [btz, mlen], dtype=torch.long, device=word_emb.device - ) - cat_ids = torch.cat([mem_pad, segment_ids], dim=1) - else: - cat_ids = segment_ids - segment_ids = (segment_ids[:, :, None] != cat_ids[:, None]).long() - index_ += 1 - else: - segment_ids = None - - if self.attn_type in {"uni", 0}: - attention_mask = self.create_mask(word_emb, qlen, klen, mlen) - elif self.attn_type == "bi": - attention_mask = ( - (inputs[0] != self.token_pad_ids).long().unsqueeze(1).unsqueeze(2) - ) - non_tgt_mask = torch.eye(qlen).to(attention_mask)[None, None, :, :] - non_tgt_mask = ((1 - attention_mask - non_tgt_mask) <= 0).long() - - return [word_emb, segment_ids, pos_emb, non_tgt_mask, None] - - def apply_main_layers(self, inputs): - hidden_states, segment_ids, pos_emb, attention_mask, conditional_emb = inputs[ - :5 - ] - encoded_layers = [hidden_states] - - layer_inputs = [ - hidden_states, - segment_ids, - pos_emb, - attention_mask, - None, - conditional_emb, - ] - for i, layer_module in enumerate(self.encoderLayer): - mems_i = None if self.mems is None else self.mems[i] - layer_inputs[-2] = mems_i - layer_inputs = self.apply_on_layer_begin(i, layer_inputs) - hidden_states = layer_module(*layer_inputs) - layer_inputs[0] = hidden_states - layer_inputs = self.apply_on_layer_end(i, layer_inputs) - encoded_layers.append(hidden_states) - - hidden_states = self.dropout(hidden_states) - qlen = inputs[0].size(1) - mlen = self.mems[0].size(0) if self.mems is not None else 0 - self._update_mems(encoded_layers, mlen, qlen) - - if not self.output_all_encoded_layers: - encoded_layers = encoded_layers[:1] + [hidden_states] - return [encoded_layers, conditional_emb] - - def load_variable(self, state_dict, name, prefix=""): - if (self.keep_tokens is not None) or (self.compound_tokens is not None): - raise ValueError( - "Custom keep_tokens and compound_tokens is not yet supported in Transformer_XL" - ) - return state_dict[name] - - def variable_mapping(self, prefix=""): - return {k: k for k, v in self.named_parameters()} - - -class XLNET(Transformer_XL): - - def __init__(self, *args, bi_data=False, **kwargs): - self.attn_type = kwargs.get("attn_type", "bi") - self.bi_data = bi_data - kwargs["rel_shift_opt"] = "xlnet" - super().__init__(*args, **kwargs) - - def relative_positional_encoding(self, qlen, klen, device): - if self.attn_type == "bi": - beg, end = klen, -qlen - elif self.attn_type == "uni": - beg, end = klen, -1 - else: - raise ValueError(f"Unknown `attn_type` {self.attn_type}.") - - pos_seq = torch.arange(beg, end, -1.0, device=device, dtype=torch.long) - if self.clamp_len > 0: - pos_seq.clamp_(max=self.clamp_len) - fwd_pos_emb = self.pos_embeddings(pos_seq) - - if self.bi_data: - pos_seq = torch.arange(-beg, -end, -1.0, device=device, dtype=torch.long) - if self.clamp_len > 0: - pos_seq.clamp_(max=self.clamp_len) - bwd_pos_emb = self.pos_embeddings(pos_seq) - pos_emb = torch.cat([fwd_pos_emb, bwd_pos_emb], dim=0) - else: - pos_emb = fwd_pos_emb - - pos_emb = self.dropout(pos_emb) - return pos_emb - - def apply_final_layers(self, inputs): - hidden_state = super().apply_final_layers(inputs) - if self.with_lm: - 
return [hidden_state, self.dense(hidden_state)] - else: - return hidden_state - - def load_variable(self, state_dict, name, prefix="transformer"): - variable = state_dict[name] - if name in { - f"{prefix}.word_embedding.weight", - "lm_loss.weight", - "lm_loss.bias", - }: - return self.load_embeddings(variable) - elif re.search("rel_attn\.(q|k|v|r)$", name): - return variable.reshape(variable.shape[0], -1).T - # elif re.search('rel_attn\.(o|seg_embed)$', name): - elif re.search("rel_attn\.(o)$", name): - return variable.reshape(variable.shape[0], -1) - else: - return variable - - def variable_mapping(self, prefix="transformer"): - mapping = { - "embeddings.weight": f"{prefix}.word_embedding.weight", - "dense.weight": "lm_loss.weight", - "dense.bias": "lm_loss.bias", - } - for i in range(self.num_hidden_layers): - prefix_i = f"{prefix}.layer.%d." % i - mapping.update( - { - f"encoderLayer.{i}.multiHeadAttention.q.weight": prefix_i - + "rel_attn.q", - f"encoderLayer.{i}.multiHeadAttention.k.weight": prefix_i - + "rel_attn.k", - f"encoderLayer.{i}.multiHeadAttention.v.weight": prefix_i - + "rel_attn.v", - f"encoderLayer.{i}.multiHeadAttention.o.weight": prefix_i - + "rel_attn.o", - f"encoderLayer.{i}.multiHeadAttention.r.weight": prefix_i - + "rel_attn.r", - f"encoderLayer.{i}.multiHeadAttention.r_r_bias": prefix_i - + "rel_attn.r_r_bias", - f"encoderLayer.{i}.multiHeadAttention.r_s_bias": prefix_i - + "rel_attn.r_s_bias", - f"encoderLayer.{i}.multiHeadAttention.r_w_bias": prefix_i - + "rel_attn.r_w_bias", - # f'encoderLayer.{i}.multiHeadAttention.seg_embed.weight': prefix_i + 'rel_attn.seg_embed', - f"encoderLayer.{i}.multiHeadAttention.seg_embed": prefix_i - + "rel_attn.seg_embed", - f"encoderLayer.{i}.layerNorm1.weight": prefix_i - + "rel_attn.layer_norm.weight", - f"encoderLayer.{i}.layerNorm1.bias": prefix_i - + "rel_attn.layer_norm.bias", - f"encoderLayer.{i}.feedForward.intermediateDense.weight": prefix_i - + "ff.layer_1.weight", - f"encoderLayer.{i}.feedForward.intermediateDense.bias": prefix_i - + "ff.layer_1.bias", - f"encoderLayer.{i}.feedForward.outputDense.weight": prefix_i - + "ff.layer_2.weight", - f"encoderLayer.{i}.feedForward.outputDense.bias": prefix_i - + "ff.layer_2.bias", - f"encoderLayer.{i}.layerNorm2.weight": prefix_i - + "ff.layer_norm.weight", - f"encoderLayer.{i}.layerNorm2.bias": prefix_i - + "ff.layer_norm.bias", - } - ) - - return mapping - - -def build_transformer_model( - config_path=None, - checkpoint_path=None, - model="bert", - application="encoder", - **kwargs, -): - - configs = {} - if config_path is not None: - configs.update(json.load(open(config_path))) - configs.update(kwargs) - if "max_position" not in configs: - configs["max_position"] = configs.get("max_position_embeddings", 512) - if "dropout_rate" not in configs: - configs["dropout_rate"] = configs.get("hidden_dropout_prob") - if "segment_vocab_size" not in configs: - configs["segment_vocab_size"] = configs.get("type_vocab_size", 2) - - models = { - "bert": BERT, - "roberta": BERT, - "albert": ALBERT, - "albert_unshared": ALBERT_Unshared, - "nezha": NEZHA, - "roformer": RoFormer, - "roformer_v2": RoFormerV2, - "gau_alpha": GAU_alpha, - "electra": ELECTRA, - "encoder": Encoder, - "decoder": Decoder, - "transformer": Transformer, - "bart": BART, - "gpt": GPT, - "gpt2": GPT2, - "gpt2_ml": GPT2_ML, - "t5": T5, - "t5_encoder": T5_Encoder, - "t5_decoder": T5_Decoder, - "t5.1.0": T5, - "t5.1.0_encoder": T5_Encoder, - "t5.1.0_decoder": T5_Decoder, - "t5.1.1": T5, - "t5.1.1_encoder": T5_Encoder, - 
"t5.1.1_decoder": T5_Decoder, - "mt5.1.1": T5, - "mt5.1.1_encoder": T5_Encoder, - "mt5.1.1_decoder": T5_Decoder, - "transformer_xl": Transformer_XL, - "xlnet": XLNET, - } - - if isinstance(model, str): - MODEL = models[model.lower()] - if model.endswith("t5.1.1"): - configs["version"] = model - elif isinstance(model, type) and issubclass( - model, BERT_BASE - ): - MODEL = model - else: - raise ValueError('"model" args type should be string or nn.Module') - - application = application.lower() - if application in ["lm", "unilm"] and model in [ - "electra", - "t5", - ]: - raise ValueError( - f'"{model}" model can not be used as "{application}" application.\n' - ) - - if application == "lm": - MODEL = extend_with_language_model(MODEL) - elif application == "unilm": - MODEL = extend_with_unified_language_model(MODEL) - - transformer = MODEL(**configs) - transformer.build(**configs) - transformer.apply(transformer.init_model_weights) - - if checkpoint_path is not None: - transformer.load_weights_from_pytorch_checkpoint(checkpoint_path) - transformer.configs = configs - return transformer +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import copy +import json +import re +import warnings + +import torch +import torch.nn as nn +from bert4torch.activations import get_activation +from bert4torch.layers import ( + AdaptiveEmbedding, + BertEmbeddings, + BertLayer, + GatedAttentionUnit, + Identity, + LayerNorm, + T5Layer, + XlnetLayer, + XlnetPositionsEncoding, +) +from bert4torch.snippets import ( + FGM, + PGD, + VAT, + EarlyStopping, + IterDataset, + ProgbarLogger, + delete_arguments, + get_kw, + insert_arguments, + metric_mapping, + search_layer, + take_along_dim, +) + + +class BaseModel(nn.Module): + def __init__(self): + super(BaseModel, self).__init__() + ( + self.global_step, + self.local_step, + self.total_steps, + self.epoch, + self.train_dataloader, + ) = (0, 0, 0, 0, None) + self.callbacks = [] + + def compile( + self, + loss, + optimizer, + scheduler=None, + max_grad_norm=None, + use_amp=False, + metrics=None, + adversarial_train={"name": ""}, + ): + self.criterion = loss + self.optimizer = optimizer + self.scheduler = scheduler + self.max_grad_norm = max_grad_norm + self.use_amp = use_amp + if use_amp: + assert adversarial_train["name"] not in { + "vat", + "gradient_penalty", + }, "Amp and adversarial_train both run is not supported in current version" + from torch.cuda.amp import autocast + + self.autocast = autocast + self.scaler = torch.cuda.amp.GradScaler() + + if metrics is None: + metrics = [] + self.metrics = ["loss"] + [i for i in metrics if i != "loss"] + + # 对抗训练 + self.adversarial = adversarial_train + self.adversarial_initialize() + + def adversarial_initialize(self): + assert self.adversarial["name"] in { + "", + "fgm", + "pgd", + "vat", + "gradient_penalty", + }, "adversarial_train support fgm, pgd, vat and gradient_penalty mode" + self.adversarial["epsilon"] = 
self.adversarial.get("epsilon", 1.0) + self.adversarial["emb_name"] = self.adversarial.get( + "emb_name", "word_embeddings" + ) + + if self.adversarial["name"] == "fgm": + self.ad_train = FGM(self) + elif self.adversarial["name"] == "pgd": + self.adversarial["K"] = self.adversarial.get("K", 3) # 步数 + self.adversarial["alpha"] = self.adversarial.get("alpha", 0.3) # 学习率 + self.ad_train = PGD(self) + elif self.adversarial["name"] == "gradient_penalty": + pass + elif self.adversarial["name"] == "vat": + self.adversarial["K"] = self.adversarial.get("K", 3) + self.adversarial["noise_var"] = self.adversarial.get( + "noise_var", 1e-5 + ) + self.adversarial["noise_gamma"] = self.adversarial.get( + "noise_gamma", 1e-6 + ) + self.adversarial["adv_step_size"] = self.adversarial.get( + "adv_step_size", 1e-3 + ) + self.adversarial["adv_alpha"] = self.adversarial.get( + "adv_alpha", 1 + ) + self.adversarial["norm_type"] = self.adversarial.get( + "norm_type", "l2" + ) + self.ad_train = VAT(self, **self.adversarial) + + def adversarial_training( + self, train_X, train_y, output, loss, loss_detail, grad_accumulation_steps + ): + """对抗训练""" + if self.adversarial["name"] == "fgm": + self.ad_train.attack(**self.adversarial) + output, loss, loss_detail = self.train_step( + train_X, train_y, grad_accumulation_steps + ) + loss.backward() + self.ad_train.restore(**self.adversarial) + elif self.adversarial["name"] == "pgd": + self.ad_train.backup_grad() + for t in range(self.adversarial["K"]): + self.ad_train.attack(**self.adversarial, is_first_attack=(t == 0)) + if t != self.adversarial["K"] - 1: + self.optimizer.zero_grad() + else: + self.ad_train.restore_grad() + output, loss, loss_detail = self.train_step( + train_X, train_y, grad_accumulation_steps + ) + loss.backward() + self.ad_train.restore(**self.adversarial) + elif self.adversarial["name"] == "gradient_penalty": + para = search_layer(self, self.adversarial["emb_name"], retrun_first=True) + gp = (para.grad**2).sum() + loss += 0.5 * gp * self.adversarial["epsilon"] + loss.backward() + elif self.adversarial["name"] == "vat": + logit = output[0] if isinstance(output, (list, tuple)) else output + adv_loss = self.ad_train.virtual_adversarial_training(train_X, logit) + loss_detail.update({"loss_sup": loss.item(), "loss_unsup": adv_loss}) + loss += adv_loss if adv_loss else 0 + loss.backward() + + return loss, loss_detail + + def train_step(self, train_X, train_y, grad_accumulation_steps): + + def args_segmentate(train_X): + if isinstance(train_X, torch.Tensor): + pass + elif isinstance(self, (BaseModelDP, BaseModelDDP)): + if self.module.forward.__code__.co_argcount >= 3: + return True + elif self.forward.__code__.co_argcount >= 3: + return True + return False + + if self.use_amp: + with self.autocast(): + output = ( + self.forward(*train_X) + if args_segmentate(train_X) + else self.forward(train_X) + ) + loss_detail = self.criterion(output, train_y) + else: + output = ( + self.forward(*train_X) + if args_segmentate(train_X) + else self.forward(train_X) + ) + loss_detail = self.criterion(output, train_y) + + if isinstance(loss_detail, torch.Tensor): + loss = loss_detail + loss_detail = {} + elif isinstance(loss_detail, dict): + loss = loss_detail["loss"] + del loss_detail["loss"] + elif isinstance(loss_detail, (tuple, list)): + loss = loss_detail[0] + loss_detail = { + f"loss{i}": v for i, v in enumerate(loss_detail[1:], start=1) + } + else: + raise ValueError("Return loss only support Tensor/dict/tuple/list format") + # 梯度累积 + loss = loss / 
grad_accumulation_steps if grad_accumulation_steps > 1 else loss + return output, loss, loss_detail + + def callback_fun(self, mode, logs={}): + if ( + isinstance(self, BaseModelDDP) + and self.master_rank != torch.distributed.get_rank() + ): + return + + if mode == "train_begin": + for callback in self.callbacks: + callback.on_train_begin() + elif mode == "epoch_begin": + for callback in self.callbacks: + callback.on_epoch_begin(self.global_step, self.epoch, logs) + elif mode == "batch_begin": + for callback in self.callbacks: + callback.on_batch_begin(self.global_step, self.local_step, logs) + elif mode == "batch_end": + for callback in self.callbacks: + callback.on_batch_end(self.global_step, self.local_step, logs) + elif mode == "epoch_end": + for callback in self.callbacks: + callback.on_epoch_end(self.global_step, self.epoch, logs) + elif mode == "train_end": + for callback in self.callbacks: + callback.on_train_end() + elif mode == "dataloader_end": + for callback in self.callbacks: + callback.on_dataloader_end() + + def fit( + self, + train_dataloader, + steps_per_epoch=None, + epochs=1, + grad_accumulation_steps=1, + callbacks=[], + ): + if isinstance(train_dataloader.dataset, IterDataset): + assert ( + steps_per_epoch is not None + ), "IterDataset should specify steps_per_epoch" + steps_per_epoch = ( + len(train_dataloader) if steps_per_epoch is None else steps_per_epoch + ) + self.total_steps = steps_per_epoch * epochs + self.global_step = 0 + self.train_dataloader = train_dataloader + train_dataloader_iter = iter(self.train_dataloader) + + self.callbacks = [ProgbarLogger(epochs, steps_per_epoch, self.metrics)] + ( + callbacks if isinstance(callbacks, (list, tuple)) else [callbacks] + ) + self.callback_fun("train_begin") + + self.bti = 0 + for epoch in range(epochs): + if isinstance( + self.train_dataloader.sampler, + torch.utils.data.distributed.DistributedSampler, + ): + self.train_dataloader.sampler.set_epoch(epoch) + self.epoch = epoch + self.callback_fun("epoch_begin") + for local_step in range(steps_per_epoch): + self.local_step = local_step + try: + batch = next(train_dataloader_iter) + except StopIteration: + self.callback_fun( + "dataloader_end" + ) + train_dataloader_iter = iter( + self.train_dataloader + ) + self.bti = 0 + batch = next(train_dataloader_iter) + train_X, train_y = batch + + if isinstance(train_X, (list, tuple)): + if isinstance(train_X[0], (list, tuple)): + btz = train_X[0][0].size(0) + else: + btz = train_X[0].size(0) + elif isinstance(train_X, torch.Tensor): + btz = train_X.size(0) + else: + raise ValueError("Input only support [list, tuple, tensor]") + logs = {"batch": self.local_step, "size": btz} + self.callback_fun("batch_begin", logs) + + self.train() + output, loss, loss_detail = self.train_step( + train_X, train_y, grad_accumulation_steps + ) + + retain_graph = ( + True + if self.adversarial["name"] in {"gradient_penalty", "vat"} + else False + ) + if self.use_amp: + scale_before_step = self.scaler.get_scale() + self.scaler.scale(loss).backward(retain_graph=retain_graph) + else: + loss.backward(retain_graph=retain_graph) + + loss, loss_detail = self.adversarial_training( + train_X, train_y, output, loss, loss_detail, grad_accumulation_steps + ) + + if (self.global_step + 1) % grad_accumulation_steps == 0: + skip_scheduler = False + if self.use_amp: + self.scaler.unscale_(self.optimizer) + if self.max_grad_norm is not None: + torch.nn.utils.clip_grad_norm_( + self.parameters(), self.max_grad_norm + ) + self.scaler.step(self.optimizer) + 
self.scaler.update() + skip_scheduler = self.scaler.get_scale() != scale_before_step + else: + if self.max_grad_norm is not None: + torch.nn.utils.clip_grad_norm_( + self.parameters(), self.max_grad_norm + ) + self.optimizer.step() + + self.optimizer.zero_grad() + if (self.scheduler is not None) and not skip_scheduler: + self.scheduler.step() + + # 添加log打印 + logs.update({"loss": loss.item()}) + logs_loss_detail = { + k: v.item() if isinstance(v, torch.Tensor) else v + for k, v in loss_detail.items() + } + logs.update(logs_loss_detail) + if self.global_step == 0: + self.callbacks[0].add_metrics( + list(logs_loss_detail.keys()), add_position=1 + ) + for metric in self.metrics: + tmp = metric_mapping(metric, output, train_y) + if tmp is not None: + logs[metric] = tmp + self.callback_fun("batch_end", logs) + + self.bti += 1 + self.global_step += 1 + self.callback_fun("epoch_end", logs) + callback_tmp = [ + callback_tmp + for callback_tmp in self.callbacks + if isinstance(callback_tmp, EarlyStopping) + ] + if callback_tmp and callback_tmp[0].stopped_epoch > 0: + break + self.callback_fun("train_end", logs) + + @torch.no_grad() + def predict(self, input_tensor_list, return_all=None): + self.eval() + if self.forward.__code__.co_argcount >= 3: + output = self.forward(*input_tensor_list) + else: + output = self.forward(input_tensor_list) + if return_all is None: + return output + elif ( + isinstance(output, (tuple, list)) + and isinstance(return_all, int) + and return_all < len(output) + ): + return output[return_all] + else: + raise ValueError("Return format error") + + def load_weights(self, load_path, strict=True, prefix=None): + state_dict = torch.load(load_path, map_location="cpu") + if prefix is None: + self.load_state_dict(state_dict, strict=strict) + else: + eval_str = ( + "self.variable_mapping()" + if prefix == "" + else f"self.{prefix}.variable_mapping()" + ) + mapping = {v: k for k, v in eval(eval_str).items()} + mapping = ( + mapping + if prefix == "" + else {k: f"{prefix}.{v}" for k, v in mapping.items()} + ) + state_dict_raw = {} + for k, v in state_dict.items(): + k = mapping.get(k, k) + state_dict_raw[k] = v + self.load_state_dict(state_dict_raw, strict=strict) + + def save_weights(self, save_path, prefix=None): + if prefix is None: + torch.save(self.state_dict(), save_path) + else: + eval_str = ( + "self.variable_mapping()" + if prefix == "" + else f"self.{prefix}.variable_mapping()" + ) + mapping = eval(eval_str) + mapping = ( + mapping + if prefix == "" + else {f"{prefix}.{k}": v for k, v in mapping.items()} + ) + state_dict_raw = {} + for k, v in self.state_dict().items(): + k = mapping.get(k, k) + state_dict_raw[k] = v + torch.save(state_dict_raw, save_path) + + +class BaseModelDP(BaseModel, nn.DataParallel): + + def __init__(self, *args, **kwargs): + nn.DataParallel.__init__(self, *args, **kwargs) + + +class BaseModelDDP(BaseModel, nn.parallel.DistributedDataParallel): + + def __init__(self, *args, master_rank=0, **kwargs): + self.master_rank = master_rank + nn.parallel.DistributedDataParallel.__init__(self, *args, **kwargs) + + +class BERT_BASE(BaseModel): + """模型基类""" + + def __init__( + self, + vocab_size, + hidden_size, + num_hidden_layers, + num_attention_heads, + intermediate_size, + hidden_act, + dropout_rate=None, + attention_probs_dropout_prob=None, + embedding_size=None, + attention_head_size=None, + attention_key_size=None, + initializer_range=0.02, + sequence_length=None, + keep_tokens=None, + compound_tokens=None, + residual_attention_scores=False, + 
ignore_invalid_weights=False, + keep_hidden_layers=None, + hierarchical_position=None, + **kwargs, + ): + super(BERT_BASE, self).__init__() + if keep_tokens is not None: + vocab_size = len(keep_tokens) + if compound_tokens is not None: + vocab_size += len(compound_tokens) + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.attention_head_size = ( + attention_head_size or self.hidden_size // self.num_attention_heads + ) + self.attention_key_size = attention_key_size or self.attention_head_size + self.intermediate_size = intermediate_size + self.dropout_rate = dropout_rate or 0 + self.attention_probs_dropout_prob = attention_probs_dropout_prob or 0 + self.hidden_act = hidden_act + self.embedding_size = embedding_size or hidden_size + self.initializer_range = initializer_range + self.sequence_length = sequence_length + self.keep_tokens = keep_tokens + self.compound_tokens = compound_tokens + self.attention_bias = None + self.position_bias = None + self.attention_scores = None + self.residual_attention_scores = residual_attention_scores + self.ignore_invalid_weights = ignore_invalid_weights + self.keep_hidden_layers = ( + set(range(num_hidden_layers)) + if keep_hidden_layers is None + else set(keep_hidden_layers) + ) + self.hierarchical_position = hierarchical_position + + def build( + self, + attention_caches=None, + layer_norm_cond=None, + layer_norm_cond_hidden_size=None, + layer_norm_cond_hidden_act=None, + additional_input_layers=None, + **kwargs, + ): + self.attention_caches = attention_caches or {} + self.output_all_encoded_layers = kwargs.get("output_all_encoded_layers", False) + + def forward(self, inputs): + # Embedding + outputs = self.apply_embeddings(inputs) + # Main + outputs = self.apply_main_layers(outputs) + # Final + outputs = self.apply_final_layers(outputs) + return outputs + + def init_model_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)) and ( + module.weight.requires_grad + ): + module.weight.data.normal_(mean=0.0, std=self.initializer_range) + elif isinstance(module, LayerNorm): + if ( + hasattr(module, "bias") and module.bias.requires_grad + ): + module.bias.data.zero_() + if hasattr(module, "weight") and module.weight.requires_grad: + module.weight.data.fill_(1.0) + if ( + isinstance(module, nn.Linear) + and (module.bias is not None) + and (module.bias.requires_grad) + ): + module.bias.data.zero_() + + def variable_mapping(self): + return {} + + def load_load_variable(self): + raise NotImplementedError + + def load_embeddings(self, embeddings): + if self.keep_tokens is not None: + embeddings = embeddings[self.keep_tokens] + + if self.compound_tokens is not None: + ext_embeddings = [] + for item in self.compound_tokens: + try: + ext_embeddings.append( + torch.mean(embeddings[item], 0) + * torch.ones_like(embeddings[item]) + ) + except IndexError: + ext_embeddings.append(torch.mean(embeddings, 0, keepdim=True)) + warnings.warn( + f"Initialize ext_embeddings from compound_tokens not in embedding index" + ) + embeddings = torch.cat([embeddings] + ext_embeddings, 0) + + return embeddings + + def load_pos_embeddings(self, embeddings): + if self.hierarchical_position is not None: + alpha = ( + 0.4 + if self.hierarchical_position is True + else self.hierarchical_position + ) + embeddings = embeddings - alpha * embeddings[:1] + embeddings = embeddings / (1 - alpha) + position_index = torch.arange(self.max_position)[:, None] + + embeddings_x 
= take_along_dim( + embeddings, + torch.div(position_index, embeddings.size(0), rounding_mode="trunc"), + dim=0, + ) + embeddings_y = take_along_dim( + embeddings, position_index % embeddings.size(0), dim=0 + ) + embeddings = alpha * embeddings_x + (1 - alpha) * embeddings_y + + return embeddings + + def load_weights_from_pytorch_checkpoint(self, checkpoint, mapping=None): + file_state_dict = torch.load(checkpoint, map_location="cpu") + mapping = mapping or self.variable_mapping() + parameters_set = set([i[0] for i in self.named_parameters()]) + + for layer_name in parameters_set: + if (layer_name in file_state_dict) and (layer_name not in mapping): + mapping.update({layer_name: layer_name}) + + state_dict_new = {} + for new_key, old_key in mapping.items(): + if new_key not in self.state_dict(): + continue + elif old_key in file_state_dict: + state_dict_new[new_key] = self.load_variable(file_state_dict, old_key) + elif (old_key not in file_state_dict) and (not self.ignore_invalid_weights): + print(f"[WARNIMG] {old_key} not found in pretrain models") + if new_key in parameters_set: + parameters_set.remove(new_key) + + if not self.ignore_invalid_weights: + for key in parameters_set: + print(f"[WARNIMG] Parameter {key} not loaded from pretrain models") + del file_state_dict + + self.load_state_dict(state_dict_new, strict=False) + + + def apply_embeddings(self, inputs): + raise NotImplementedError + + def apply_main_layers(self, inputs): + raise NotImplementedError + + def apply_final_layers(self, inputs): + raise NotImplementedError + + def apply_on_layer_begin(self, l_i, inputs): + + return inputs + + def apply_on_layer_end(self, l_i, inputs): + + return inputs + + def compute_attention_bias(self, inputs=None): + + return self.attention_bias + + def compute_position_bias(self, inputs=None): + + return self.position_bias + + def set_outputs(self, outputs): + + if not isinstance(outputs, list): + outputs = [outputs] + + outputs = outputs[:] + self.outputs = outputs + if len(outputs) > 1: + self.output = outputs + else: + self.output = outputs[0] + + +class LM_Mask(object): + + def compute_attention_bias(self, inputs=None): + seq_len = inputs[0].shape[1] + attention_bias = torch.tril( + torch.ones(seq_len, seq_len, dtype=torch.long, device=inputs[0].device), + diagonal=0, + ) + self.attention_bias = attention_bias.unsqueeze(0).unsqueeze(1) + return self.attention_bias + + +def extend_with_language_model(InputModel): + + class LanguageModel(LM_Mask, InputModel): + + def __init__(self, *args, **kwargs): + kwargs["with_mlm"] = kwargs.get("with_mlm") or True + super(LanguageModel, self).__init__(*args, **kwargs) + + return LanguageModel + + +class UniLM_Mask(object): + def compute_attention_bias(self, inputs=None): + segment_ids = inputs[1] + attention_bias = torch.cumsum(segment_ids, dim=1) + attention_bias = (attention_bias.unsqueeze(1)) <= (attention_bias.unsqueeze(2)) + self.attention_bias = attention_bias.unsqueeze(1).long() + + return self.attention_bias + + +def extend_with_unified_language_model(InputModel): + + class UnifiedLanguageModel(UniLM_Mask, InputModel): + + def __init__(self, *args, **kwargs): + kwargs["with_mlm"] = kwargs.get("with_mlm") or True + super(UnifiedLanguageModel, self).__init__(*args, **kwargs) + + return UnifiedLanguageModel + + +class BERT(BERT_BASE): + def __init__( + self, + max_position, + segment_vocab_size=2, + with_pool=False, + with_nsp=False, + with_mlm=False, + custom_position_ids=False, + custom_attention_mask=False, + shared_segment_embeddings=False, + 
layer_norm_cond=None, + layer_add_embs=None, + is_dropout=False, + token_pad_ids=0, + **kwargs, + ): + super(BERT, self).__init__(**kwargs) + self.max_position = max_position + self.segment_vocab_size = segment_vocab_size + self.with_pool = with_pool + self.with_nsp = with_nsp + self.with_mlm = with_mlm + self.custom_position_ids = custom_position_ids + self.custom_attention_mask = custom_attention_mask + self.shared_segment_embeddings = shared_segment_embeddings + self.is_dropout = is_dropout + self.token_pad_ids = token_pad_ids + if self.with_nsp and not self.with_pool: + self.with_pool = True + self.layer_norm_conds = layer_norm_cond + self.layer_add_embs = layer_add_embs + self.conditional_size = ( + layer_norm_cond.weight.size(1) if layer_norm_cond is not None else None + ) + self.embeddings = BertEmbeddings( + self.vocab_size, + self.embedding_size, + self.hidden_size, + self.max_position, + self.segment_vocab_size, + self.shared_segment_embeddings, + self.dropout_rate, + self.conditional_size, + **get_kw(BertEmbeddings, kwargs), + ) + kwargs["max_position"] = self.max_position + layer = BertLayer( + self.hidden_size, + self.num_attention_heads, + self.dropout_rate, + self.attention_probs_dropout_prob, + self.intermediate_size, + self.hidden_act, + is_dropout=self.is_dropout, + conditional_size=self.conditional_size, + **get_kw(BertLayer, kwargs), + ) + self.encoderLayer = nn.ModuleList( + [ + copy.deepcopy(layer) + if layer_id in self.keep_hidden_layers + else Identity() + for layer_id in range(self.num_hidden_layers) + ] + ) + if self.with_pool: + + self.pooler = nn.Linear(self.hidden_size, self.hidden_size) + self.pooler_activation = ( + nn.Tanh() if self.with_pool is True else get_activation(self.with_pool) + ) + if self.with_nsp: + + self.nsp = nn.Linear(self.hidden_size, 2) + else: + self.pooler = None + self.pooler_activation = None + if self.with_mlm: + self.mlmDense = nn.Linear(self.hidden_size, self.hidden_size) + self.transform_act_fn = get_activation(self.hidden_act) + self.mlmLayerNorm = LayerNorm( + self.hidden_size, eps=1e-12, conditional_size=self.conditional_size + ) + self.mlmDecoder = nn.Linear(self.hidden_size, self.vocab_size, bias=False) + if kwargs.get("tie_emb_prj_weight") is True: + self.mlmDecoder.weight = self.embeddings.word_embeddings.weight + self.mlmBias = nn.Parameter(torch.zeros(self.vocab_size)) + self.mlmDecoder.bias = self.mlmBias + + + def apply_embeddings(self, inputs): + token_ids = inputs[0] + index_ = 1 + if self.segment_vocab_size > 0: + segment_ids = inputs[index_] + index_ += 1 + else: + segment_ids = None + + if self.custom_position_ids: + position_ids = inputs[index_] + index_ += 1 + else: + position_ids = None + + if self.custom_attention_mask: + attention_mask = inputs[index_].long().unsqueeze(1).unsqueeze(2) + index_ += 1 + elif (not token_ids.requires_grad) and ( + token_ids.dtype in {torch.long, torch.int} + ): + attention_mask = ( + (token_ids != self.token_pad_ids).long().unsqueeze(1).unsqueeze(2) + ) + if self.token_pad_ids < 0: + token_ids = token_ids * attention_mask[:, 0, 0, :] + else: + attention_mask = self.attention_mask_cache + self.attention_mask_cache = attention_mask + + self.compute_attention_bias([token_ids, segment_ids]) + if self.attention_bias is not None: + attention_mask = attention_mask * self.attention_bias + + try: + attention_mask = attention_mask.to( + dtype=next(self.parameters()).dtype + ) + except StopIteration: + attention_mask = attention_mask.to(dtype=torch.float32) + + if self.layer_norm_conds is None: 
+ conditional_emb = None + else: + conditional_emb = self.layer_norm_conds(inputs[index_]) + index_ += 1 + + + if isinstance(self.layer_add_embs, nn.Module): + additional_embs = [self.layer_add_embs(inputs[index_])] + index_ += 1 + elif isinstance(self.layer_add_embs, (tuple, list)): + additional_embs = [] + for layer in self.layer_add_embs: + assert isinstance( + layer, nn.Module + ), "Layer_add_embs element should be nn.Module" + additional_embs.append(layer(inputs[index_])) + index_ += 1 + else: + additional_embs = None + + + hidden_states = self.embeddings( + token_ids, segment_ids, conditional_emb, additional_embs + ) + return [hidden_states, attention_mask, conditional_emb] + inputs[index_:] + + def apply_main_layers(self, inputs): + hidden_states, attention_mask, conditional_emb = inputs[:3] + if len(inputs[3:]) >= 2: + encoder_hidden_state, encoder_attention_mask = inputs[3], inputs[4] + else: + encoder_hidden_state, encoder_attention_mask = None, None + + encoded_layers = [hidden_states] + layer_inputs = [ + hidden_states, + attention_mask, + conditional_emb, + encoder_hidden_state, + encoder_attention_mask, + ] + for l_i, layer_module in enumerate(self.encoderLayer): + layer_inputs = self.apply_on_layer_begin(l_i, layer_inputs) + hidden_states = layer_module(*layer_inputs) + layer_inputs[0] = hidden_states + layer_inputs = self.apply_on_layer_end(l_i, layer_inputs) + + if self.output_all_encoded_layers: + encoded_layers.append(hidden_states) + if not self.output_all_encoded_layers: + encoded_layers.append(hidden_states) + return [encoded_layers, conditional_emb] + + def apply_final_layers(self, inputs): + encoded_layers, conditional_emb = inputs + sequence_output = encoded_layers[-1] + + if not self.output_all_encoded_layers: + encoded_layers = encoded_layers[-1] + + + if self.with_pool: + pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0])) + else: + pooled_output = None + + if self.with_pool and self.with_nsp: + nsp_scores = self.nsp(pooled_output) + else: + nsp_scores = None + + if self.with_mlm: + mlm_hidden_state = self.mlmDense(sequence_output) + mlm_hidden_state = self.transform_act_fn(mlm_hidden_state) + mlm_hidden_state = self.mlmLayerNorm((mlm_hidden_state, conditional_emb)) + mlm_scores = self.mlmDecoder(mlm_hidden_state) + mlm_activation = get_activation( + "linear" if self.with_mlm is True else self.with_mlm + ) + mlm_scores = mlm_activation(mlm_scores) + else: + mlm_scores = None + + outputs = [ + value + for value in [encoded_layers, pooled_output, mlm_scores, nsp_scores] + if value is not None + ] + return outputs if len(outputs) > 1 else outputs[0] + + def load_variable(self, state_dict, name, prefix="bert"): + variable = state_dict[name] + if name in { + f"{prefix}.embeddings.word_embeddings.weight", + "cls.predictions.bias", + "cls.predictions.decoder.weight", + "cls.predictions.decoder.bias", + }: + return self.load_embeddings(variable) + elif name == f"{prefix}.embeddings.position_embeddings.weight": + return self.load_pos_embeddings(variable) + elif name == "cls.seq_relationship.weight": + return variable.T + else: + return variable + + def variable_mapping(self, prefix="bert"): + mapping = { + "embeddings.word_embeddings.weight": f"{prefix}.embeddings.word_embeddings.weight", + "embeddings.position_embeddings.weight": f"{prefix}.embeddings.position_embeddings.weight", + "embeddings.segment_embeddings.weight": f"{prefix}.embeddings.token_type_embeddings.weight", + "embeddings.layerNorm.weight": f"{prefix}.embeddings.LayerNorm.weight", + 
"embeddings.layerNorm.bias": f"{prefix}.embeddings.LayerNorm.bias", + "pooler.weight": f"{prefix}.pooler.dense.weight", + "pooler.bias": f"{prefix}.pooler.dense.bias", + "nsp.weight": "cls.seq_relationship.weight", + "nsp.bias": "cls.seq_relationship.bias", + "mlmDense.weight": "cls.predictions.transform.dense.weight", + "mlmDense.bias": "cls.predictions.transform.dense.bias", + "mlmLayerNorm.weight": "cls.predictions.transform.LayerNorm.weight", + "mlmLayerNorm.bias": "cls.predictions.transform.LayerNorm.bias", + "mlmBias": "cls.predictions.bias", + "mlmDecoder.weight": "cls.predictions.decoder.weight", + "mlmDecoder.bias": "cls.predictions.decoder.bias", + } + for i in range(self.num_hidden_layers): + prefix_i = f"{prefix}.encoder.layer.%d." % i + mapping.update( + { + f"encoderLayer.{i}.multiHeadAttention.q.weight": prefix_i + + "attention.self.query.weight", + f"encoderLayer.{i}.multiHeadAttention.q.bias": prefix_i + + "attention.self.query.bias", + f"encoderLayer.{i}.multiHeadAttention.k.weight": prefix_i + + "attention.self.key.weight", + f"encoderLayer.{i}.multiHeadAttention.k.bias": prefix_i + + "attention.self.key.bias", + f"encoderLayer.{i}.multiHeadAttention.v.weight": prefix_i + + "attention.self.value.weight", + f"encoderLayer.{i}.multiHeadAttention.v.bias": prefix_i + + "attention.self.value.bias", + f"encoderLayer.{i}.multiHeadAttention.o.weight": prefix_i + + "attention.output.dense.weight", + f"encoderLayer.{i}.multiHeadAttention.o.bias": prefix_i + + "attention.output.dense.bias", + f"encoderLayer.{i}.layerNorm1.weight": prefix_i + + "attention.output.LayerNorm.weight", + f"encoderLayer.{i}.layerNorm1.bias": prefix_i + + "attention.output.LayerNorm.bias", + f"encoderLayer.{i}.feedForward.intermediateDense.weight": prefix_i + + "intermediate.dense.weight", + f"encoderLayer.{i}.feedForward.intermediateDense.bias": prefix_i + + "intermediate.dense.bias", + f"encoderLayer.{i}.feedForward.outputDense.weight": prefix_i + + "output.dense.weight", + f"encoderLayer.{i}.feedForward.outputDense.bias": prefix_i + + "output.dense.bias", + f"encoderLayer.{i}.layerNorm2.weight": prefix_i + + "output.LayerNorm.weight", + f"encoderLayer.{i}.layerNorm2.bias": prefix_i + + "output.LayerNorm.bias", + } + ) + + return mapping + + +class ALBERT(BERT): + def __init__(self, *args, **kwargs): + super(ALBERT, self).__init__(*args, **kwargs) + self.encoderLayer = nn.ModuleList([self.encoderLayer[0]]) + + def apply_main_layers(self, inputs): + hidden_states, attention_mask, conditional_emb = inputs[:3] + if len(inputs[3:]) >= 2: + encoder_hidden_state, encoder_attention_mask = inputs[3], inputs[4] + else: + encoder_hidden_state, encoder_attention_mask = None, None + + encoded_layers = [hidden_states] + layer_inputs = [ + hidden_states, + attention_mask, + conditional_emb, + encoder_hidden_state, + encoder_attention_mask, + ] + for l_i in range(self.num_hidden_layers): + layer_inputs = self.apply_on_layer_begin(l_i, layer_inputs) + hidden_states = self.encoderLayer[0](*layer_inputs) + layer_inputs[0] = hidden_states + layer_inputs = self.apply_on_layer_end(l_i, layer_inputs) + + if self.output_all_encoded_layers: + encoded_layers.append(hidden_states) + if not self.output_all_encoded_layers: + encoded_layers.append(hidden_states) + return [encoded_layers, conditional_emb] + + def variable_mapping(self, prefix="albert"): + mapping = { + "embeddings.word_embeddings.weight": f"{prefix}.embeddings.word_embeddings.weight", + "embeddings.position_embeddings.weight": 
f"{prefix}.embeddings.position_embeddings.weight", + "embeddings.segment_embeddings.weight": f"{prefix}.embeddings.token_type_embeddings.weight", + "embeddings.layerNorm.weight": f"{prefix}.embeddings.LayerNorm.weight", + "embeddings.layerNorm.bias": f"{prefix}.embeddings.LayerNorm.bias", + "embeddings.embedding_hidden_mapping_in.weight": f"{prefix}.encoder.embedding_hidden_mapping_in.weight", + "embeddings.embedding_hidden_mapping_in.bias": f"{prefix}.encoder.embedding_hidden_mapping_in.bias", + "pooler.weight": f"{prefix}.pooler.weight", + "pooler.bias": f"{prefix}.pooler.bias", + "nsp.weight": "sop_classifier.classifier.weight", + "nsp.bias": "sop_classifier.classifier.bias", + "mlmDense.weight": "predictions.dense.weight", + "mlmDense.bias": "predictions.dense.bias", + "mlmLayerNorm.weight": "predictions.LayerNorm.weight", + "mlmLayerNorm.bias": "predictions.LayerNorm.bias", + "mlmBias": "predictions.bias", + "mlmDecoder.weight": "predictions.decoder.weight", + "mlmDecoder.bias": "predictions.decoder.bias", + } + i = 0 + prefix_i = f"{prefix}.encoder.albert_layer_groups.{i}.albert_layers.{i}." + mapping.update( + { + f"encoderLayer.{i}.multiHeadAttention.q.weight": prefix_i + + "attention.query.weight", + f"encoderLayer.{i}.multiHeadAttention.q.bias": prefix_i + + "attention.query.bias", + f"encoderLayer.{i}.multiHeadAttention.k.weight": prefix_i + + "attention.key.weight", + f"encoderLayer.{i}.multiHeadAttention.k.bias": prefix_i + + "attention.key.bias", + f"encoderLayer.{i}.multiHeadAttention.v.weight": prefix_i + + "attention.value.weight", + f"encoderLayer.{i}.multiHeadAttention.v.bias": prefix_i + + "attention.value.bias", + f"encoderLayer.{i}.multiHeadAttention.o.weight": prefix_i + + "attention.dense.weight", + f"encoderLayer.{i}.multiHeadAttention.o.bias": prefix_i + + "attention.dense.bias", + f"encoderLayer.{i}.layerNorm1.weight": prefix_i + + "attention.LayerNorm.weight", + f"encoderLayer.{i}.layerNorm1.bias": prefix_i + + "attention.LayerNorm.bias", + f"encoderLayer.{i}.feedForward.intermediateDense.weight": prefix_i + + "ffn.weight", + f"encoderLayer.{i}.feedForward.intermediateDense.bias": prefix_i + + "ffn.bias", + f"encoderLayer.{i}.feedForward.outputDense.weight": prefix_i + + "ffn_output.weight", + f"encoderLayer.{i}.feedForward.outputDense.bias": prefix_i + + "ffn_output.bias", + f"encoderLayer.{i}.layerNorm2.weight": prefix_i + + "full_layer_layer_norm.weight", + f"encoderLayer.{i}.layerNorm2.bias": prefix_i + + "full_layer_layer_norm.bias", + } + ) + + return mapping + + def load_variable(self, state_dict, name): + + variable = state_dict[name] + if name in { + "albert.embeddings.word_embeddings.weight", + "predictions.bias", + "predictions.decoder.weight", + "predictions.decoder.bias", + }: + return self.load_embeddings(variable) + elif name == "albert.embeddings.position_embeddings.weight": + return self.load_pos_embeddings(variable) + elif name == "sop_classifier.classifier.weight": + return variable.T + else: + return variable + + +class ALBERT_Unshared(ALBERT): + def __init__(self, *args, **kwargs): + super(ALBERT_Unshared).__init__(*args, **kwargs) + self.encoderLayer = nn.ModuleList( + [copy.deepcopy(self.encoderLayer[0]) for _ in range(self.num_hidden_layers)] + ) + + def apply_main_layers(self, inputs): + + hidden_states, attention_mask, conditional_emb = inputs + if len(inputs[3:]) >= 2: + encoder_hidden_state, encoder_attention_mask = inputs[3], inputs[4] + else: + encoder_hidden_state, encoder_attention_mask = None, None + + encoded_layers = 
[hidden_states] # 添加embedding的输出 + layer_inputs = [ + hidden_states, + attention_mask, + conditional_emb, + encoder_hidden_state, + encoder_attention_mask, + ] + for i in range(self.num_hidden_layers): + layer_inputs = self.apply_on_layer_begin(i, layer_inputs) + hidden_states = self.encoderLayer[i](*layer_inputs) + layer_inputs[0] = hidden_states + layer_inputs = self.apply_on_layer_end(i, layer_inputs) + + if self.output_all_encoded_layers: + encoded_layers.append(hidden_states) + if not self.output_all_encoded_layers: + encoded_layers.append(hidden_states) + return [encoded_layers, conditional_emb] + + +class NEZHA(BERT): + def __init__(self, *args, **kwargs): + + kwargs.update( + { + "p_bias": "typical_relative", + "max_relative_position": kwargs.get("max_relative_position", 64), + } + ) + super(NEZHA, self).__init__(*args, **kwargs) + + +class RoFormer(BERT): + def __init__(self, *args, **kwargs): + kwargs.update({"p_bias": "rotary"}) + super(RoFormer, self).__init__(*args, **kwargs) + + def load_variable(self, state_dict, name, prefix="roformer"): + return super().load_variable(state_dict, name, prefix) + + def variable_mapping(self, prefix="roformer"): + mapping = super().variable_mapping(prefix) + del mapping["embeddings.position_embeddings.weight"] + return mapping + + +class RoFormerV2(RoFormer): + @delete_arguments("with_pool", "with_nsp") + def __init__(self, *args, **kwargs): + kwargs.update( + {"p_bias": "rotary", "weight": False, "bias": False, "norm_mode": "rmsnorm"} + ) + super(RoFormerV2, self).__init__(*args, **kwargs) + if self.with_mlm: + del self.mlmLayerNorm + del self.mlmBias + del self.mlmDense + self.mlmDecoder.register_parameter("bias", None) + + def variable_mapping(self, prefix="roformer"): + mapping = super().variable_mapping(prefix) + mapping_new = {} + for k, v in mapping.items(): + if (not re.search("bias|layernorm", k.lower())) and ( + not re.search("bias|layernorm", v.lower()) + ): + mapping_new[k] = v + return mapping_new + + def apply_final_layers(self, inputs): + encoded_layers, conditional_emb = inputs + sequence_output = encoded_layers[-1] + if not self.output_all_encoded_layers: + encoded_layers = encoded_layers[-1] + + if self.with_mlm: + mlm_scores = self.mlmDecoder(sequence_output) + else: + mlm_scores = None + + outputs = [value for value in [encoded_layers, mlm_scores] if value is not None] + return outputs if len(outputs) > 1 else outputs[0] + + +class GAU_alpha(RoFormerV2): + def __init__(self, *args, **kwargs): + kwargs.update( + { + "p_bias": "rotary", + "weight": False, + "bias": False, + "norm_mode": "rmsnorm", + "normalization": "softmax_plus", + } + ) + super().__init__(*args, **kwargs) + + layer = self.GAU_Layer(**kwargs) + self.encoderLayer = nn.ModuleList( + [ + copy.deepcopy(layer) + if layer_id in self.keep_hidden_layers + else Identity() + for layer_id in range(self.num_hidden_layers) + ] + ) + + def load_variable(self, state_dict, name, prefix=""): + variable = state_dict[name] + return ( + self.load_embeddings(variable) + if name in {"embeddings.word_embeddings.weight", "mlmDecoder.weight"} + else variable + ) + + def variable_mapping(self, prefix=""): + return {k: k for k, _ in self.named_parameters()} + + class GAU_Layer(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + self.gau = GatedAttentionUnit(**kwargs) + self.dropout1 = nn.Dropout(kwargs.get("dropout_rate")) + self.layerNorm1 = LayerNorm(**kwargs) + + def forward( + self, + hidden_states, + attention_mask, + conditional_emb=None, + 
encoder_hidden_states=None, + encoder_attention_mask=None, + ): + gau_hidden_states = self.gau(hidden_states, attention_mask) + hidden_states = hidden_states + self.dropout1(gau_hidden_states) + hidden_states = self.layerNorm1((hidden_states, conditional_emb)) + return hidden_states + + +class ELECTRA(BERT): + @insert_arguments(with_discriminator=False) + @delete_arguments("with_pool", "with_mlm", "with_nsp") + def __init__(self, max_position, **kwargs): + super(ELECTRA, self).__init__(max_position, **kwargs) + if self.with_discriminator: + self.dense = nn.Linear(self.hidden_size, self.hidden_size) + self.dense_act = get_activation(self.hidden_act) + self.dense_prediction = nn.Linear(self.hidden_size, 1) + self.dense_prediction_act = ( + get_activation("sigmoid") + if self.with_discriminator is True + else get_activation(self.with_discriminator) + ) + + def apply_final_layers(self, inputs): + hidden_states = super().apply_final_layers(inputs) + if self.with_discriminator: + logit = self.dense_act(self.dense(hidden_states)) + return [ + hidden_states, + self.dense_prediction_act(self.dense_prediction(logit)), + ] + else: + return hidden_states + + def load_variable(self, state_dict, name): + return super().load_variable(state_dict, name, prefix="electra") + + def variable_mapping(self): + mapping = super(ELECTRA, self).variable_mapping(prefix="electra") + mapping.update( + { + "dense.weight": "discriminator_predictions.dense.weight", + "dense.bias": "discriminator_predictions.dense.bias", + "dense_prediction.weight": "discriminator_predictions.dense_prediction.weight", + "dense_prediction.bias": "discriminator_predictions.dense_prediction.bias", + } + ) + for del_key in [ + "pooler.weight", + "pooler.bias", + "nsp.weight", + "nsp.bias", + "mlmDense.weight", + "mlmDense.bias", + "mlmLayerNorm.weight", + "mlmLayerNorm.bias", + "mlmBias", + "mlmDecoder.weight", + "mlmDecoder.bias", + ]: + del mapping[del_key] + + return mapping + + +class Encoder(BERT): + def __init__(self, *args, **kwargs): + kwargs["vocab_size"] = kwargs.get("src_vocab_size", kwargs["vocab_size"]) + super().__init__(*args, **kwargs) + self.encoder_attention_mask = None + + def forward(self, inputs): + # Embedding + outputs = self.apply_embeddings(inputs) + encoder_attention_mask = [outputs[1]] + # Main + outputs = self.apply_main_layers(outputs) + # Final + outputs = self.apply_final_layers(outputs) + return ( + [outputs] if isinstance(outputs, torch.Tensor) else outputs + ) + encoder_attention_mask + + +class Decoder(LM_Mask, BERT): + @delete_arguments("with_pool", "with_mlm", "with_nsp") + def __init__(self, *args, with_lm=True, tie_emb_prj_weight=True, **kwargs): + kwargs["vocab_size"] = kwargs.get("tgt_vocab_size", kwargs["vocab_size"]) + kwargs["is_decoder"] = True + super().__init__(*args, **kwargs) + self.decoderLayer = self.encoderLayer + del self.encoderLayer + self.with_lm = with_lm + + if self.with_lm: + self.final_dense = nn.Linear(self.hidden_size, self.vocab_size, bias=False) + if tie_emb_prj_weight: + self.final_dense.weight = self.embeddings.word_embeddings.weight + self.x_logit_scale = self.hidden_size**-0.5 + else: + self.x_logit_scale = 1.0 + + def apply_main_layers(self, inputs): + ( + hidden_states, + attention_mask, + conditional_emb, + encoder_hidden_state, + encoder_attention_mask, + ) = inputs[:5] + decoded_layers = [hidden_states] + layer_inputs = [ + hidden_states, + attention_mask, + conditional_emb, + encoder_hidden_state, + encoder_attention_mask, + ] + for i, layer_module in 
enumerate(self.decoderLayer): + layer_inputs = self.apply_on_layer_begin(i, layer_inputs) + hidden_states = layer_module(*layer_inputs) + layer_inputs[0] = hidden_states + layer_inputs = self.apply_on_layer_end(i, layer_inputs) + + if self.output_all_encoded_layers: + decoded_layers.append(hidden_states) + if not self.output_all_encoded_layers: + decoded_layers.append(hidden_states) + return [decoded_layers, conditional_emb] + + def apply_final_layers(self, inputs): + outputs = [] + hidden_states = super().apply_final_layers( + inputs + ) + outputs.append(hidden_states) + if self.with_lm: + logits = ( + self.final_dense(hidden_states) * self.x_logit_scale + ) + activation = get_activation( + "linear" if self.with_lm is True else self.with_lm + ) + logits = activation(logits) + outputs.append(logits) + return outputs + + def variable_mapping(self, prefix="bert"): + raw_mapping = super().variable_mapping(prefix) + mapping = {} + for k, v in raw_mapping.items(): + mapping[k.replace("encoderLayer", "decoderLayer")] = v + return mapping + + +class Transformer(BERT_BASE): + """encoder-decoder结构""" + + @delete_arguments("with_pool", "with_mlm", "with_nsp") + def __init__(self, *args, tie_emb_src_tgt_weight=False, **kwargs): + super(Transformer, self).__init__(*args, **kwargs) + + # encoder + self.encoder = Encoder(*args, **kwargs) + self.encoder.build(**kwargs) + + # decoder + self.decoder = Decoder(*args, **kwargs) + self.decoder.build(**kwargs) + + if tie_emb_src_tgt_weight: + assert ( + self.encoder.vocab_size == self.decoder.vocab_size + ), "To share word embedding, the vocab size of src/tgt shall be the same." + self.encoder.embeddings.word_embeddings.weight = ( + self.decoder.embeddings.word_embeddings.weight + ) + + def forward(self, inputs): + encoder_input, decoder_input = inputs[:2] + + # encoder + # encoder_emb = self.encoder.apply_embeddings(encoder_input) + # encode_outputs = self.encoder.apply_main_layers(encoder_emb) + # encoder_hidden_state = self.encoder.apply_final_layers(encode_outputs) + # encoder_attention_mask = encoder_emb[1] + encoder_hidden_state, encoder_attention_mask = self.encoder(encoder_input) + + # decoder + # decoder_emb = self.decoder.apply_embeddings(decoder_input) + # decoder_outputs = self.decoder.apply_main_layers([*decoder_emb, encoder_hidden_state, encoder_attention_mask]) + # decoder_outputs = self.decoder.apply_final_layers(decoder_outputs) # [hidden_states, logits] + decoder_outputs = self.decoder( + decoder_input + [encoder_hidden_state, encoder_attention_mask] + ) + return [ + encoder_hidden_state + ] + decoder_outputs + + +class BART(Transformer): + """encoder-decoder结构""" + + def __init__(self, *args, tie_emb_src_tgt_weight=True, **kwargs): + super(BART, self).__init__( + *args, tie_emb_src_tgt_weight=tie_emb_src_tgt_weight, **kwargs + ) + self.tie_emb_src_tgt_weight = tie_emb_src_tgt_weight + + def load_variable(self, state_dict, name, prefix=""): + variable = state_dict[name] + if name in { + "shared.weight", + "encoder.embed_tokens.weight", + "decoder.embed_tokens.weight", + }: + return self.load_embeddings(variable) + elif name in { + "encoder.embed_positions.weight", + "decoder.embed_positions.weight", + }: + return self.load_pos_embeddings(variable) + else: + return variable + + def variable_mapping(self, prefix=""): + mapping = { + "encoder.embeddings.word_embeddings.weight": "shared.weight" + if self.tie_emb_src_tgt_weight + else "encoder.embed_tokens.weight", + "encoder.embeddings.position_embeddings.weight": 
"encoder.embed_positions.weight", + "encoder.embeddings.layerNorm.weight": "encoder.layernorm_embedding.weight", + "encoder.embeddings.layerNorm.bias": "encoder.layernorm_embedding.bias", + "decoder.embeddings.word_embeddings.weight": "shared.weight" + if self.tie_emb_src_tgt_weight + else "decoder.embed_tokens.weight", + "decoder.embeddings.position_embeddings.weight": "decoder.embed_positions.weight", + "decoder.embeddings.layerNorm.weight": "decoder.layernorm_embedding.weight", + "decoder.embeddings.layerNorm.bias": "decoder.layernorm_embedding.bias", + } + for i in range(self.num_hidden_layers): + mapping.update( + { + f"encoder.encoderLayer.{i}.multiHeadAttention.q.weight": f"encoder.layers.{i}.self_attn.q_proj.weight", + f"encoder.encoderLayer.{i}.multiHeadAttention.q.bias": f"encoder.layers.{i}.self_attn.q_proj.bias", + f"encoder.encoderLayer.{i}.multiHeadAttention.k.weight": f"encoder.layers.{i}.self_attn.k_proj.weight", + f"encoder.encoderLayer.{i}.multiHeadAttention.k.bias": f"encoder.layers.{i}.self_attn.k_proj.bias", + f"encoder.encoderLayer.{i}.multiHeadAttention.v.weight": f"encoder.layers.{i}.self_attn.v_proj.weight", + f"encoder.encoderLayer.{i}.multiHeadAttention.v.bias": f"encoder.layers.{i}.self_attn.v_proj.bias", + f"encoder.encoderLayer.{i}.multiHeadAttention.o.weight": f"encoder.layers.{i}.self_attn.out_proj.weight", + f"encoder.encoderLayer.{i}.multiHeadAttention.o.bias": f"encoder.layers.{i}.self_attn.out_proj.bias", + f"encoder.encoderLayer.{i}.layerNorm1.weight": f"encoder.layers.{i}.self_attn_layer_norm.weight", + f"encoder.encoderLayer.{i}.layerNorm1.bias": f"encoder.layers.{i}.self_attn_layer_norm.bias", + f"encoder.encoderLayer.{i}.feedForward.intermediateDense.weight": f"encoder.layers.{i}.fc1.weight", + f"encoder.encoderLayer.{i}.feedForward.intermediateDense.bias": f"encoder.layers.{i}.fc1.bias", + f"encoder.encoderLayer.{i}.feedForward.outputDense.weight": f"encoder.layers.{i}.fc2.weight", + f"encoder.encoderLayer.{i}.feedForward.outputDense.bias": f"encoder.layers.{i}.fc2.bias", + f"encoder.encoderLayer.{i}.layerNorm2.weight": f"encoder.layers.{i}.final_layer_norm.weight", + f"encoder.encoderLayer.{i}.layerNorm2.bias": f"encoder.layers.{i}.final_layer_norm.bias", + f"decoder.decoderLayer.{i}.multiHeadAttention.q.weight": f"decoder.layers.{i}.self_attn.q_proj.weight", + f"decoder.decoderLayer.{i}.multiHeadAttention.q.bias": f"decoder.layers.{i}.self_attn.q_proj.bias", + f"decoder.decoderLayer.{i}.multiHeadAttention.k.weight": f"decoder.layers.{i}.self_attn.k_proj.weight", + f"decoder.decoderLayer.{i}.multiHeadAttention.k.bias": f"decoder.layers.{i}.self_attn.k_proj.bias", + f"decoder.decoderLayer.{i}.multiHeadAttention.v.weight": f"decoder.layers.{i}.self_attn.v_proj.weight", + f"decoder.decoderLayer.{i}.multiHeadAttention.v.bias": f"decoder.layers.{i}.self_attn.v_proj.bias", + f"decoder.decoderLayer.{i}.multiHeadAttention.o.weight": f"decoder.layers.{i}.self_attn.out_proj.weight", + f"decoder.decoderLayer.{i}.multiHeadAttention.o.bias": f"decoder.layers.{i}.self_attn.out_proj.bias", + f"decoder.decoderLayer.{i}.layerNorm1.weight": f"decoder.layers.{i}.self_attn_layer_norm.weight", + f"decoder.decoderLayer.{i}.layerNorm1.bias": f"decoder.layers.{i}.self_attn_layer_norm.bias", + f"decoder.decoderLayer.{i}.crossAttention.q.weight": f"decoder.layers.{i}.encoder_attn.q_proj.weight", + f"decoder.decoderLayer.{i}.crossAttention.q.bias": f"decoder.layers.{i}.encoder_attn.q_proj.bias", + f"decoder.decoderLayer.{i}.crossAttention.k.weight": 
f"decoder.layers.{i}.encoder_attn.k_proj.weight", + f"decoder.decoderLayer.{i}.crossAttention.k.bias": f"decoder.layers.{i}.encoder_attn.k_proj.bias", + f"decoder.decoderLayer.{i}.crossAttention.v.weight": f"decoder.layers.{i}.encoder_attn.v_proj.weight", + f"decoder.decoderLayer.{i}.crossAttention.v.bias": f"decoder.layers.{i}.encoder_attn.v_proj.bias", + f"decoder.decoderLayer.{i}.crossAttention.o.weight": f"decoder.layers.{i}.encoder_attn.out_proj.weight", + f"decoder.decoderLayer.{i}.crossAttention.o.bias": f"decoder.layers.{i}.encoder_attn.out_proj.bias", + f"decoder.decoderLayer.{i}.layerNorm3.weight": f"decoder.layers.{i}.encoder_attn_layer_norm.weight", + f"decoder.decoderLayer.{i}.layerNorm3.bias": f"decoder.layers.{i}.encoder_attn_layer_norm.bias", + f"decoder.decoderLayer.{i}.feedForward.intermediateDense.weight": f"decoder.layers.{i}.fc1.weight", + f"decoder.decoderLayer.{i}.feedForward.intermediateDense.bias": f"decoder.layers.{i}.fc1.bias", + f"decoder.decoderLayer.{i}.feedForward.outputDense.weight": f"decoder.layers.{i}.fc2.weight", + f"decoder.decoderLayer.{i}.feedForward.outputDense.bias": f"decoder.layers.{i}.fc2.bias", + f"decoder.decoderLayer.{i}.layerNorm2.weight": f"decoder.layers.{i}.final_layer_norm.weight", + f"decoder.decoderLayer.{i}.layerNorm2.bias": f"decoder.layers.{i}.final_layer_norm.bias", + } + ) + + return mapping + + +class T5_Encoder(Encoder): + @insert_arguments(version="t5.1.0") + def __init__(self, *args, **kwargs): + kwargs.update( + { + "p_bias": "t5_relative", + "relative_attention_num_buckets": kwargs.get( + "relative_attention_num_buckets" + ), + "version": self.version, + "bias": False, + "norm_mode": "rmsnorm", + } + ) + super().__init__(*args, **kwargs) + del self.embeddings.layerNorm + + layer = T5Layer( + self.hidden_size, + self.num_attention_heads, + self.dropout_rate, + self.attention_probs_dropout_prob, + self.intermediate_size, + self.hidden_act, + is_dropout=self.is_dropout, + conditional_size=self.conditional_size, + **get_kw(BertLayer, kwargs), + ) + self.encoderLayer = nn.ModuleList( + [copy.deepcopy(layer) for _ in range(self.num_hidden_layers)] + ) + + for i in range(1, self.num_hidden_layers): + self.encoderLayer[ + i + ].multiHeadAttention.relative_positions_encoding.weight = self.encoderLayer[ + 0 + ].multiHeadAttention.relative_positions_encoding.weight + self.final_layer_norm = LayerNorm( + self.hidden_size, + eps=1e-12, + conditional_size=self.conditional_size, + bias=False, + mode="rmsnorm", + ) + self.dropout = nn.Dropout(self.dropout_rate) + + def apply_final_layers(self, inputs): + hidden_states = super().apply_final_layers(inputs) + return self.dropout(self.final_layer_norm([hidden_states])) + + def load_variable(self, state_dict, name, prefix=""): + variable = state_dict[name] + if name in {"encoder.embed_tokens.weight", "shared.weight"}: + return self.load_embeddings(variable) + else: + return variable + + def variable_mapping(self, prefix=""): + mapping = { + f"{prefix}embeddings.word_embeddings.weight": "encoder.embed_tokens.weight", + f"{prefix}encoderLayer.0.multiHeadAttention.relative_positions_encoding.weight": "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight", + f"{prefix}final_layer_norm.weight": "encoder.final_layer_norm.weight", + } + for i in range(self.num_hidden_layers): + mapping.update( + { + f"{prefix}encoderLayer.{i}.multiHeadAttention.q.weight": f"encoder.block.{i}.layer.0.SelfAttention.q.weight", + f"{prefix}encoderLayer.{i}.multiHeadAttention.k.weight": 
f"encoder.block.{i}.layer.0.SelfAttention.k.weight", + f"{prefix}encoderLayer.{i}.multiHeadAttention.v.weight": f"encoder.block.{i}.layer.0.SelfAttention.v.weight", + f"{prefix}encoderLayer.{i}.multiHeadAttention.o.weight": f"encoder.block.{i}.layer.0.SelfAttention.o.weight", + f"{prefix}encoderLayer.{i}.layerNorm1.weight": f"encoder.block.{i}.layer.0.layer_norm.weight", + f"{prefix}encoderLayer.{i}.feedForward.outputDense.weight": f"encoder.block.{i}.layer.1.DenseReluDense.wo.weight", + f"{prefix}encoderLayer.{i}.layerNorm2.weight": f"encoder.block.{i}.layer.1.layer_norm.weight", + } + ) + + if self.version.endswith("t5.1.0"): + mapping.update( + { + f"{prefix}encoderLayer.{i}.feedForward.intermediateDense.weight": f"encoder.block.{i}.layer.1.DenseReluDense.wi.weight" + } + ) + elif self.version.endswith("t5.1.1"): + mapping.update( + { + f"{prefix}encoderLayer.{i}.feedForward.intermediateDense.weight": f"encoder.block.{i}.layer.1.DenseReluDense.wi_0.weight", + f"{prefix}encoderLayer.{i}.feedForward.intermediateDense1.weight": f"encoder.block.{i}.layer.1.DenseReluDense.wi_1.weight", + } + ) + return mapping + + +class T5_Decoder(Decoder): + @insert_arguments(version="t5.1.0") + def __init__(self, *args, **kwargs): + kwargs.update( + { + "p_bias": "t5_relative", + "relative_attention_num_buckets": kwargs.get( + "relative_attention_num_buckets" + ), + "version": self.version, + "bias": False, + "norm_mode": "rmsnorm", + } + ) + super().__init__(*args, **kwargs) + del self.embeddings.layerNorm + + layer = T5Layer( + self.hidden_size, + self.num_attention_heads, + self.dropout_rate, + self.attention_probs_dropout_prob, + self.intermediate_size, + self.hidden_act, + is_dropout=self.is_dropout, + conditional_size=self.conditional_size, + is_decoder=True, + **get_kw(BertLayer, kwargs), + ) + self.decoderLayer = nn.ModuleList( + [copy.deepcopy(layer) for _ in range(self.num_hidden_layers)] + ) + + for i in range(1, self.num_hidden_layers): + self.decoderLayer[ + i + ].multiHeadAttention.relative_positions_encoding.weight = self.decoderLayer[ + 0 + ].multiHeadAttention.relative_positions_encoding.weight + self.final_layer_norm = LayerNorm( + self.hidden_size, + eps=1e-12, + conditional_size=self.conditional_size, + bias=False, + mode="rmsnorm", + ) + self.dropout = nn.Dropout(self.dropout_rate) + + def apply_final_layers(self, inputs): + inputs[0][1] = self.dropout( + self.final_layer_norm([inputs[0][1]]) + ) + return super().apply_final_layers(inputs) + + def load_variable(self, state_dict, name, prefix=""): + variable = state_dict[name] + if name in {f"decoder.embed_tokens.weight", "lm_head.weight", "shared.weight"}: + return self.load_embeddings(variable) + else: + return variable + + def variable_mapping(self, prefix=""): + mapping = { + f"{prefix}embeddings.word_embeddings.weight": "decoder.embed_tokens.weight", + f"{prefix}decoderLayer.0.multiHeadAttention.relative_positions_encoding.weight": "decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight", + f"{prefix}final_layer_norm.weight": "decoder.final_layer_norm.weight", + f"{prefix}final_dense.weight": "lm_head.weight", + } + + for i in range(self.num_hidden_layers): + mapping.update( + { + f"{prefix}decoderLayer.{i}.multiHeadAttention.q.weight": f"decoder.block.{i}.layer.0.SelfAttention.q.weight", + f"{prefix}decoderLayer.{i}.multiHeadAttention.k.weight": f"decoder.block.{i}.layer.0.SelfAttention.k.weight", + f"{prefix}decoderLayer.{i}.multiHeadAttention.v.weight": f"decoder.block.{i}.layer.0.SelfAttention.v.weight", + 
f"{prefix}decoderLayer.{i}.multiHeadAttention.o.weight": f"decoder.block.{i}.layer.0.SelfAttention.o.weight", + f"{prefix}decoderLayer.{i}.layerNorm1.weight": f"decoder.block.{i}.layer.0.layer_norm.weight", + f"{prefix}decoderLayer.{i}.crossAttention.q.weight": f"decoder.block.{i}.layer.1.EncDecAttention.q.weight", + f"{prefix}decoderLayer.{i}.crossAttention.k.weight": f"decoder.block.{i}.layer.1.EncDecAttention.k.weight", + f"{prefix}decoderLayer.{i}.crossAttention.v.weight": f"decoder.block.{i}.layer.1.EncDecAttention.v.weight", + f"{prefix}decoderLayer.{i}.crossAttention.o.weight": f"decoder.block.{i}.layer.1.EncDecAttention.o.weight", + f"{prefix}decoderLayer.{i}.layerNorm3.weight": f"decoder.block.{i}.layer.1.layer_norm.weight", + f"{prefix}decoderLayer.{i}.feedForward.outputDense.weight": f"decoder.block.{i}.layer.2.DenseReluDense.wo.weight", + f"{prefix}decoderLayer.{i}.layerNorm2.weight": f"decoder.block.{i}.layer.2.layer_norm.weight", + } + ) + + if self.version.endswith("t5.1.0"): + mapping.update( + { + f"{prefix}decoderLayer.{i}.feedForward.intermediateDense.weight": f"decoder.block.{i}.layer.2.DenseReluDense.wi.weight" + } + ) + elif self.version.endswith("t5.1.1"): + mapping.update( + { + f"{prefix}decoderLayer.{i}.feedForward.intermediateDense.weight": f"decoder.block.{i}.layer.2.DenseReluDense.wi_0.weight", + f"{prefix}decoderLayer.{i}.feedForward.intermediateDense1.weight": f"decoder.block.{i}.layer.2.DenseReluDense.wi_1.weight", + } + ) + return mapping + + +class T5(Transformer): + @delete_arguments("with_pool", "with_mlm", "with_nsp") + def __init__(self, *args, tie_emb_src_tgt_weight=True, **kwargs): + super(T5, self).__init__(*args, **kwargs) + self.tie_emb_src_tgt_weight = tie_emb_src_tgt_weight + + # encoder + self.encoder = T5_Encoder(*args, **kwargs) + self.encoder.build(**kwargs) + + # decoder + self.decoder = T5_Decoder(*args, **kwargs) + self.decoder.build(**kwargs) + + def load_variable(self, state_dict, name, prefix=""): + variable = state_dict[name] + if name in { + "shared.weight", + "encoder.embed_tokens.weight", + "decoder.embed_tokens.weight", + "lm_head.weight", + }: + return self.load_embeddings(variable) + else: + return variable + + def variable_mapping(self, prefix=""): + mapping = self.encoder.variable_mapping(prefix="encoder.") + mapping.update(self.decoder.variable_mapping(prefix="decoder.")) + if self.tie_emb_src_tgt_weight: + mapping.update( + { + "encoder.embeddings.word_embeddings.weight": "shared.weight", + "decoder.embeddings.word_embeddings.weight": "shared.weight", + } + ) + return mapping + + +class GPT(LM_Mask, BERT): + @insert_arguments(final_activation="softmax") + @delete_arguments("with_pool", "with_mlm", "with_nsp") + def __init__(self, max_position, **kwargs): + super(GPT, self).__init__(max_position, **kwargs) + del self.embeddings.layerNorm + self.dense = nn.Linear(self.hidden_size, self.vocab_size, bias=False) + self.dense.weight = self.embeddings.word_embeddings.weight + self.final_activation = get_activation(self.final_activation) + + def apply_final_layers(self, inputs): + hidden_state = super().apply_final_layers(inputs) + logit = self.dense(hidden_state) + return self.final_activation(logit) + + def load_variable(self, state_dict, name): + return super(GPT, self).load_variable(state_dict, name, prefix="gpt") + + def variable_mapping(self): + mapping = super(GPT, self).variable_mapping(prefix="gpt") + return mapping + + +class GPT2(LM_Mask, BERT): + + @insert_arguments(final_activation="softmax") + 
@delete_arguments("with_pool", "with_mlm", "with_nsp") + def __init__(self, max_position, **kwargs): + super(GPT2, self).__init__(max_position, **kwargs) + del self.embeddings.layerNorm + layer = self.Gpt2Layer( + self.hidden_size, + self.num_attention_heads, + self.dropout_rate, + self.attention_probs_dropout_prob, + self.intermediate_size, + self.hidden_act, + is_dropout=self.is_dropout, + conditional_size=self.conditional_size, + ) + self.encoderLayer = nn.ModuleList( + [ + copy.deepcopy(layer) + if layer_id in self.keep_hidden_layers + else Identity() + for layer_id in range(self.num_hidden_layers) + ] + ) + self.LayerNormFinal = LayerNorm( + self.hidden_size, eps=1e-12, conditional_size=self.conditional_size + ) + self.dense = nn.Linear(self.hidden_size, self.vocab_size, bias=False) + self.dense.weight = self.embeddings.word_embeddings.weight + self.final_activation = get_activation(self.final_activation) + + def apply_final_layers(self, inputs): + hidden_state = super().apply_final_layers(inputs) + logit = self.dense(self.LayerNormFinal([hidden_state])) + return self.final_activation(logit) + + def load_variable(self, state_dict, name): + return super(GPT2, self).load_variable(state_dict, name, prefix="gpt2") + + def variable_mapping(self): + mapping = super(GPT2, self).variable_mapping(prefix="gpt2") + mapping.update( + { + "LayerNormFinal.weight": "gpt2.LayerNormFinal.weight", + "LayerNormFinal.bias": "gpt2.LayerNormFinal.bias", + } + ) + return mapping + + class Gpt2Layer(BertLayer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states, + attention_mask, + conditional_emb=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + ): + x = self.layerNorm1((hidden_states, conditional_emb)) + self_attn_output = self.multiHeadAttention(x, attention_mask) + hidden_states = hidden_states + self.dropout1(self_attn_output) + x = self.layerNorm2((hidden_states, conditional_emb)) + ffn_output = self.feedForward(x) + hidden_states = hidden_states + self.dropout2(ffn_output) + return hidden_states + + +class GPT2_ML(LM_Mask, BERT): + @insert_arguments(final_activation="softmax") + @delete_arguments("with_pool", "with_mlm", "with_nsp") + def __init__(self, max_position, **kwargs): + super().__init__(max_position, **kwargs) + layer = self.Gpt2MlLayer( + self.hidden_size, + self.num_attention_heads, + self.dropout_rate, + self.attention_probs_dropout_prob, + self.intermediate_size, + self.hidden_act, + is_dropout=self.is_dropout, + conditional_size=self.conditional_size, + ) + self.encoderLayer = nn.ModuleList( + [ + copy.deepcopy(layer) + if layer_id in self.keep_hidden_layers + else Identity() + for layer_id in range(self.num_hidden_layers) + ] + ) + self.dense = nn.Linear(self.hidden_size, self.vocab_size, bias=False) + self.dense.weight = self.embeddings.word_embeddings.weight + self.final_activation = get_activation(self.final_activation) + + def apply_final_layers(self, inputs): + hidden_state = super().apply_final_layers(inputs) + logit = self.dense(hidden_state) + return self.final_activation(logit) + + def load_variable(self, state_dict, name): + return super(GPT2_ML, self).load_variable(state_dict, name, prefix="gpt2_ml") + + def variable_mapping(self): + mapping = super(GPT2_ML, self).variable_mapping(prefix="gpt2_ml") + return mapping + + class Gpt2MlLayer(BertLayer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states, + attention_mask, + 
conditional_emb=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + ): + self_attn_output = self.multiHeadAttention(hidden_states, attention_mask) + hidden_states = hidden_states + self.dropout1(self_attn_output) + x = self.layerNorm1((hidden_states, conditional_emb)) + ffn_output = self.feedForward(x) + hidden_states = hidden_states + self.dropout2(ffn_output) + hidden_states = self.layerNorm2((hidden_states, conditional_emb)) + return hidden_states + + +class Transformer_XL(BERT): + @delete_arguments("with_pool", "with_nsp", "with_mlm") + @insert_arguments(with_lm=False) + def __init__(self, *args, mem_len=0, same_length=False, clamp_len=-1, **kwargs): + # p_bias来控制embedding阶段无pos_embedding + kwargs.update({"p_bias": "other_relative"}) + super().__init__(*args, **kwargs) + self.mem_len, self.same_length, self.clamp_len = mem_len, same_length, clamp_len + self.attn_type = kwargs.get("attn_type", 0) + + # embedding + if kwargs.get("adaptive_embedding"): + cutoffs, div_val, sample_softmax = ( + kwargs.get("cutoffs", []), + kwargs.get("div_val", 1), + kwargs.get("sample_softmax", False), + ) + self.embeddings = AdaptiveEmbedding( + self.vocab_size, + self.embedding_size, + self.hidden_size, + cutoffs, + div_val, + sample_softmax, + **get_kw(AdaptiveEmbedding, kwargs), + ) + else: + self.embeddings = nn.Embedding(self.vocab_size, self.embedding_size) + self.pos_embeddings = XlnetPositionsEncoding(self.embedding_size) + self.dropout = nn.Dropout(self.dropout_rate) + + if not kwargs.get("untie_r"): + self.r_w_bias = nn.Parameter( + torch.FloatTensor(self.num_attention_heads, self.attention_head_size) + ) + self.r_r_bias = nn.Parameter( + torch.FloatTensor(self.num_attention_heads, self.attention_head_size) + ) + if self.segment_vocab_size > 0: + self.r_s_bias = nn.Parameter( + torch.FloatTensor( + self.num_attention_heads, self.attention_head_size + ) + ) + else: + self.r_w_bias, self.r_r_bias = None, None + self.r_s_bias = None + + # transformer block + layer = XlnetLayer( + self.hidden_size, + self.num_attention_heads, + self.dropout_rate, + self.attention_probs_dropout_prob, + self.intermediate_size, + self.hidden_act, + is_dropout=self.is_dropout, + conditional_size=self.conditional_size, + r_w_bias=self.r_w_bias, + r_r_bias=self.r_r_bias, + r_s_bias=None, + **get_kw(BertLayer, kwargs), + ) + self.encoderLayer = nn.ModuleList( + [ + copy.deepcopy(layer) + if layer_id in self.keep_hidden_layers + else Identity() + for layer_id in range(self.num_hidden_layers) + ] + ) + + # 映射 + if self.with_lm: + self.dense = nn.Linear(self.hidden_size, self.vocab_size, bias=True) + + def init_mems(self, bsz): + if isinstance(self.mem_len, (int, float)) and (self.mem_len > 0): + mems = [] + param = next(self.parameters()) + for _ in range(self.num_hidden_layers + 1): + empty = torch.zeros( + bsz, + self.mem_len, + self.hidden_size, + dtype=param.dtype, + device=param.device, + ) + mems.append(empty) + + return mems + else: + return None + + def _update_mems(self, hids, mlen, qlen): + # does not deal with None + if self.mems is None: + return None + # mems is not None + assert len(hids) == len(self.mems), "len(hids) != len(mems)" + # There are `mlen + qlen` steps that can be cached into mems + with torch.no_grad(): + new_mems = [] + end_idx = mlen + max(0, qlen) + beg_idx = max(0, end_idx - self.mem_len) + for i in range(len(hids)): + cat = torch.cat([self.mems[i], hids[i]], dim=1) + new_mems.append(cat[:, beg_idx:end_idx].detach()) + self.mems = new_mems + + def 
relative_positional_encoding(self, qlen, klen, device): + pos_seq = torch.arange(klen - 1, -1, -1.0, device=device, dtype=torch.long) + if self.clamp_len > 0: + pos_seq.clamp_(max=self.clamp_len) + pos_emb = self.dropout(self.pos_embeddings(pos_seq)) + return pos_emb + + def create_mask(self, word_emb, qlen, klen, mlen): + + if self.same_length: + all_ones = word_emb.new_ones(qlen, klen) + mask_len = klen - self.mem_len + mask_shift_len = qlen - mask_len if mask_len > 0 else qlen + attention_mask = ( + 1 + - ( + torch.triu(all_ones, 1 + mlen) + + torch.tril(all_ones, -mask_shift_len) + ).byte() + ) + else: + attention_mask = torch.tril( + word_emb.new_ones(qlen, klen), diagonal=mlen + ).byte() + attention_mask = attention_mask[None, None, :, :] + return attention_mask + + def apply_embeddings(self, inputs): + + self.mems = self.init_mems(inputs[0].size(0)) + + + word_emb = self.dropout(self.embeddings(inputs[0])) + index_ = 1 + btz, qlen = inputs[0].shape[:2] + mlen = self.mems[0].size(1) if self.mems is not None else 0 + klen = mlen + qlen + + pos_emb = self.relative_positional_encoding(qlen, klen, word_emb.device) + + if self.segment_vocab_size > 0: + segment_ids = inputs[index_] + if mlen > 0: + mem_pad = torch.zeros( + [btz, mlen], dtype=torch.long, device=word_emb.device + ) + cat_ids = torch.cat([mem_pad, segment_ids], dim=1) + else: + cat_ids = segment_ids + segment_ids = (segment_ids[:, :, None] != cat_ids[:, None]).long() + index_ += 1 + else: + segment_ids = None + + if self.attn_type in {"uni", 0}: + attention_mask = self.create_mask(word_emb, qlen, klen, mlen) + elif self.attn_type == "bi": + attention_mask = ( + (inputs[0] != self.token_pad_ids).long().unsqueeze(1).unsqueeze(2) + ) + non_tgt_mask = torch.eye(qlen).to(attention_mask)[None, None, :, :] + non_tgt_mask = ((1 - attention_mask - non_tgt_mask) <= 0).long() + + return [word_emb, segment_ids, pos_emb, non_tgt_mask, None] + + def apply_main_layers(self, inputs): + hidden_states, segment_ids, pos_emb, attention_mask, conditional_emb = inputs[ + :5 + ] + encoded_layers = [hidden_states] + + layer_inputs = [ + hidden_states, + segment_ids, + pos_emb, + attention_mask, + None, + conditional_emb, + ] + for i, layer_module in enumerate(self.encoderLayer): + mems_i = None if self.mems is None else self.mems[i] + layer_inputs[-2] = mems_i + layer_inputs = self.apply_on_layer_begin(i, layer_inputs) + hidden_states = layer_module(*layer_inputs) + layer_inputs[0] = hidden_states + layer_inputs = self.apply_on_layer_end(i, layer_inputs) + encoded_layers.append(hidden_states) + + hidden_states = self.dropout(hidden_states) + qlen = inputs[0].size(1) + mlen = self.mems[0].size(0) if self.mems is not None else 0 + self._update_mems(encoded_layers, mlen, qlen) + + if not self.output_all_encoded_layers: + encoded_layers = encoded_layers[:1] + [hidden_states] + return [encoded_layers, conditional_emb] + + def load_variable(self, state_dict, name, prefix=""): + if (self.keep_tokens is not None) or (self.compound_tokens is not None): + raise ValueError( + "Custom keep_tokens and compound_tokens is not yet supported in Transformer_XL" + ) + return state_dict[name] + + def variable_mapping(self, prefix=""): + return {k: k for k, v in self.named_parameters()} + + +class XLNET(Transformer_XL): + + def __init__(self, *args, bi_data=False, **kwargs): + self.attn_type = kwargs.get("attn_type", "bi") + self.bi_data = bi_data + kwargs["rel_shift_opt"] = "xlnet" + super().__init__(*args, **kwargs) + + def relative_positional_encoding(self, qlen, 
klen, device): + if self.attn_type == "bi": + beg, end = klen, -qlen + elif self.attn_type == "uni": + beg, end = klen, -1 + else: + raise ValueError(f"Unknown `attn_type` {self.attn_type}.") + + pos_seq = torch.arange(beg, end, -1.0, device=device, dtype=torch.long) + if self.clamp_len > 0: + pos_seq.clamp_(max=self.clamp_len) + fwd_pos_emb = self.pos_embeddings(pos_seq) + + if self.bi_data: + pos_seq = torch.arange(-beg, -end, -1.0, device=device, dtype=torch.long) + if self.clamp_len > 0: + pos_seq.clamp_(max=self.clamp_len) + bwd_pos_emb = self.pos_embeddings(pos_seq) + pos_emb = torch.cat([fwd_pos_emb, bwd_pos_emb], dim=0) + else: + pos_emb = fwd_pos_emb + + pos_emb = self.dropout(pos_emb) + return pos_emb + + def apply_final_layers(self, inputs): + hidden_state = super().apply_final_layers(inputs) + if self.with_lm: + return [hidden_state, self.dense(hidden_state)] + else: + return hidden_state + + def load_variable(self, state_dict, name, prefix="transformer"): + variable = state_dict[name] + if name in { + f"{prefix}.word_embedding.weight", + "lm_loss.weight", + "lm_loss.bias", + }: + return self.load_embeddings(variable) + elif re.search("rel_attn\.(q|k|v|r)$", name): + return variable.reshape(variable.shape[0], -1).T + # elif re.search('rel_attn\.(o|seg_embed)$', name): + elif re.search("rel_attn\.(o)$", name): + return variable.reshape(variable.shape[0], -1) + else: + return variable + + def variable_mapping(self, prefix="transformer"): + mapping = { + "embeddings.weight": f"{prefix}.word_embedding.weight", + "dense.weight": "lm_loss.weight", + "dense.bias": "lm_loss.bias", + } + for i in range(self.num_hidden_layers): + prefix_i = f"{prefix}.layer.%d." % i + mapping.update( + { + f"encoderLayer.{i}.multiHeadAttention.q.weight": prefix_i + + "rel_attn.q", + f"encoderLayer.{i}.multiHeadAttention.k.weight": prefix_i + + "rel_attn.k", + f"encoderLayer.{i}.multiHeadAttention.v.weight": prefix_i + + "rel_attn.v", + f"encoderLayer.{i}.multiHeadAttention.o.weight": prefix_i + + "rel_attn.o", + f"encoderLayer.{i}.multiHeadAttention.r.weight": prefix_i + + "rel_attn.r", + f"encoderLayer.{i}.multiHeadAttention.r_r_bias": prefix_i + + "rel_attn.r_r_bias", + f"encoderLayer.{i}.multiHeadAttention.r_s_bias": prefix_i + + "rel_attn.r_s_bias", + f"encoderLayer.{i}.multiHeadAttention.r_w_bias": prefix_i + + "rel_attn.r_w_bias", + # f'encoderLayer.{i}.multiHeadAttention.seg_embed.weight': prefix_i + 'rel_attn.seg_embed', + f"encoderLayer.{i}.multiHeadAttention.seg_embed": prefix_i + + "rel_attn.seg_embed", + f"encoderLayer.{i}.layerNorm1.weight": prefix_i + + "rel_attn.layer_norm.weight", + f"encoderLayer.{i}.layerNorm1.bias": prefix_i + + "rel_attn.layer_norm.bias", + f"encoderLayer.{i}.feedForward.intermediateDense.weight": prefix_i + + "ff.layer_1.weight", + f"encoderLayer.{i}.feedForward.intermediateDense.bias": prefix_i + + "ff.layer_1.bias", + f"encoderLayer.{i}.feedForward.outputDense.weight": prefix_i + + "ff.layer_2.weight", + f"encoderLayer.{i}.feedForward.outputDense.bias": prefix_i + + "ff.layer_2.bias", + f"encoderLayer.{i}.layerNorm2.weight": prefix_i + + "ff.layer_norm.weight", + f"encoderLayer.{i}.layerNorm2.bias": prefix_i + + "ff.layer_norm.bias", + } + ) + + return mapping + + +def build_transformer_model( + config_path=None, + checkpoint_path=None, + model="bert", + application="encoder", + **kwargs, +): + + configs = {} + if config_path is not None: + configs.update(json.load(open(config_path))) + configs.update(kwargs) + if "max_position" not in configs: + 
configs["max_position"] = configs.get("max_position_embeddings", 512) + if "dropout_rate" not in configs: + configs["dropout_rate"] = configs.get("hidden_dropout_prob") + if "segment_vocab_size" not in configs: + configs["segment_vocab_size"] = configs.get("type_vocab_size", 2) + + models = { + "bert": BERT, + "roberta": BERT, + "albert": ALBERT, + "albert_unshared": ALBERT_Unshared, + "nezha": NEZHA, + "roformer": RoFormer, + "roformer_v2": RoFormerV2, + "gau_alpha": GAU_alpha, + "electra": ELECTRA, + "encoder": Encoder, + "decoder": Decoder, + "transformer": Transformer, + "bart": BART, + "gpt": GPT, + "gpt2": GPT2, + "gpt2_ml": GPT2_ML, + "t5": T5, + "t5_encoder": T5_Encoder, + "t5_decoder": T5_Decoder, + "t5.1.0": T5, + "t5.1.0_encoder": T5_Encoder, + "t5.1.0_decoder": T5_Decoder, + "t5.1.1": T5, + "t5.1.1_encoder": T5_Encoder, + "t5.1.1_decoder": T5_Decoder, + "mt5.1.1": T5, + "mt5.1.1_encoder": T5_Encoder, + "mt5.1.1_decoder": T5_Decoder, + "transformer_xl": Transformer_XL, + "xlnet": XLNET, + } + + if isinstance(model, str): + MODEL = models[model.lower()] + if model.endswith("t5.1.1"): + configs["version"] = model + elif isinstance(model, type) and issubclass( + model, BERT_BASE + ): + MODEL = model + else: + raise ValueError('"model" args type should be string or nn.Module') + + application = application.lower() + if application in ["lm", "unilm"] and model in [ + "electra", + "t5", + ]: + raise ValueError( + f'"{model}" model can not be used as "{application}" application.\n' + ) + + if application == "lm": + MODEL = extend_with_language_model(MODEL) + elif application == "unilm": + MODEL = extend_with_unified_language_model(MODEL) + + transformer = MODEL(**configs) + transformer.build(**configs) + transformer.apply(transformer.init_model_weights) + + if checkpoint_path is not None: + transformer.load_weights_from_pytorch_checkpoint(checkpoint_path) + transformer.configs = configs + return transformer diff --git a/models/nlp/plm/bert_base_ner/igie/Int8QAT/bert4torch/snippets.py b/models/nlp/plm/bert_base_ner/igie/Int8QAT/bert4torch/snippets.py index b42628ba9bfa5df6ba40e3eb6fc300a6641e46d4..3f52fb014ea8d77db8c5dc0443e3c9add82bc7ba 100644 --- a/models/nlp/plm/bert_base_ner/igie/Int8QAT/bert4torch/snippets.py +++ b/models/nlp/plm/bert_base_ner/igie/Int8QAT/bert4torch/snippets.py @@ -1,1184 +1,1184 @@ -# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. 
- -import collections -import gc -import inspect -import json -import math -import os -import random -import re -import sys -import time -import unicodedata -import warnings - -import numpy as np -import six -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.nn.utils.rnn import pad_sequence -from torch.utils.data import Dataset, IterableDataset - -is_py2 = six.PY2 - -if not is_py2: - basestring = str - - -def take_along_dim(input_tensor, indices, dim=None): - if torch.__version__ >= "1.9.0": - return torch.take_along_dim(input_tensor, indices, dim) - else: - if dim is None: - res = input_tensor.flatten()[indices] - else: - res = np.take_along_axis( - input_tensor.cpu().numpy(), indices.cpu().numpy(), axis=dim - ) - res = torch.from_numpy(res).to(input_tensor.device) - return res - - -def is_string(s): - return isinstance(s, basestring) - - -def truncate_sequences(maxlen, indices, *sequences): - sequences = [s for s in sequences if s] - if not isinstance(indices, (list, tuple)): - indices = [indices] * len(sequences) - - while True: - lengths = [len(s) for s in sequences] - if sum(lengths) > maxlen: - i = np.argmax(lengths) - sequences[i].pop(indices[i]) - else: - return sequences - - -def text_segmentate(text, maxlen, seps="\n", strips=None, truncate=True): - text = text.strip().strip(strips) - if seps and len(text) > maxlen: - pieces = text.split(seps[0]) - text, texts = "", [] - for i, p in enumerate(pieces): - if text and p and len(text) + len(p) > maxlen - 1: - texts.extend(text_segmentate(text, maxlen, seps[1:], strips, truncate)) - text = "" - if i + 1 == len(pieces): - text = text + p - else: - text = text + p + seps[0] - if text: - texts.extend(text_segmentate(text, maxlen, seps[1:], strips, truncate)) - return texts - elif truncate and (not seps) and (len(text) > maxlen): - return [ - text[i * maxlen : (i + 1) * maxlen] - for i in range(0, int(np.ceil(len(text) / maxlen))) - ] - else: - return [text] - - -def merge_segmentate(sequences, maxlen, sep=""): - sequences_new = [] - text = "" - for t in sequences: - if text and len(text + sep + t) <= maxlen: - text = text + sep + t - elif text: - sequences_new.append(text) - text = t - elif len(t) < maxlen: - text = t - else: - sequences_new.append(t) - text = "" - if text: - sequences_new.append(text) - return sequences_new - - -def text_augmentation( - texts, - noise_dict=None, - noise_len=0, - noise_p=0.0, - skip_words=None, - strategy="random", - allow_dup=True, -): - def insert(text, insert_idx, noise_dict): - text = list(text) - for i in insert_idx: - text[i] = text[i] + random.choice(noise_dict) - return "".join(text) - - def delete(text, delete_idx): - text = list(text) - for i in delete_idx: - text[i] = "" - return "".join(text) - - def replace(text, replace_idx, noise_dict): - text = list(text) - for i in replace_idx: - text[i] = random.choice(noise_dict) - return "".join(text) - - def search(pattern, sequence, keep_last=True): - n = len(pattern) - pattern_idx_set = set() - for i in range(len(sequence)): - if sequence[i : i + n] == pattern: - pattern_idx_set = ( - pattern_idx_set.union(set(range(i, i + n))) - if keep_last - else pattern_idx_set.union(set(range(i, i + n - 1))) - ) - return pattern_idx_set - - if (noise_len == 0) and (noise_p == 0): - return texts - - assert strategy in { - "insert", - "delete", - "replace", - "random", - }, "EDA strategy only support insert, delete, replace, random" - - if isinstance(texts, str): - texts = [texts] - - if skip_words is None: - skip_words = [] - elif 
isinstance(skip_words, str): - skip_words = [skip_words] - - for id, text in enumerate(texts): - sel_len = noise_len if noise_len > 0 else int(len(text) * noise_p) - skip_idx = set() - for item in skip_words: - skip_idx = skip_idx.union(search(item, text, strategy != "insert")) - - sel_idxs = [i for i in range(len(text)) if i not in skip_idx] - sel_len = ( - sel_len if allow_dup else min(sel_len, len(sel_idxs)) - ) - if (sel_len == 0) or (len(sel_idxs) == 0): - continue - sel_idx = np.random.choice(sel_idxs, sel_len, replace=allow_dup) - if strategy == "insert": - texts[id] = insert(text, sel_idx, noise_dict) - elif strategy == "delete": - texts[id] = delete(text, sel_idx) - elif strategy == "replace": - texts[id] = replace(text, sel_idx, noise_dict) - elif strategy == "random": - if random.random() < 0.333: - skip_idx = set() - for item in skip_words: - skip_idx = skip_idx.union(search(item, text, keep_last=False)) - texts[id] = insert(text, sel_idx, noise_dict) - elif random.random() < 0.667: - texts[id] = delete(text, sel_idx) - else: - texts[id] = replace(text, sel_idx, noise_dict) - return texts if len(texts) > 1 else texts[0] - - -def lowercase_and_normalize(text, never_split=()): - if is_py2: - text = unicode(text) - - # convert non-special tokens to lowercase - escaped_special_toks = [re.escape(s_tok) for s_tok in never_split] - pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)" - text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text) - - text = unicodedata.normalize("NFD", text) - text = "".join([ch for ch in text if unicodedata.category(ch) != "Mn"]) - return text - - -def sequence_padding(inputs, length=None, value=0, seq_dims=1, mode="post"): - if isinstance(inputs[0], (np.ndarray, list)): - if length is None: - length = np.max([np.shape(x)[:seq_dims] for x in inputs], axis=0) - elif not hasattr(length, "__getitem__"): - length = [length] - - slices = [np.s_[: length[i]] for i in range(seq_dims)] - slices = tuple(slices) if len(slices) > 1 else slices[0] - pad_width = [(0, 0) for _ in np.shape(inputs[0])] - - outputs = [] - for x in inputs: - x = x[slices] - for i in range(seq_dims): - if mode == "post": - pad_width[i] = (0, length[i] - np.shape(x)[i]) - elif mode == "pre": - pad_width[i] = (length[i] - np.shape(x)[i], 0) - else: - raise ValueError('"mode" argument must be "post" or "pre".') - x = np.pad(x, pad_width, "constant", constant_values=value) - outputs.append(x) - - return np.array(outputs) - - elif isinstance(inputs[0], torch.Tensor): - assert ( - mode == "post" - ), '"mode" argument must be "post" when element is torch.Tensor' - if length is not None: - inputs = [i[:length] for i in inputs] - return pad_sequence(inputs, padding_value=value, batch_first=True) - else: - raise ValueError('"input" argument must be tensor/list/ndarray.') - - -def insert_arguments(**arguments): - def actual_decorator(func): - def new_func(self, *args, **kwargs): - for k, v in arguments.items(): - if k in kwargs: - v = kwargs.pop(k) - setattr(self, k, v) - return func(self, *args, **kwargs) - - return new_func - - return actual_decorator - - -def delete_arguments(*arguments): - def actual_decorator(func): - def new_func(self, *args, **kwargs): - for k in arguments: - if k in kwargs: - raise TypeError( - "%s got an unexpected keyword argument '%s'" - % (self.__class__.__name__, k) - ) - return func(self, *args, **kwargs) - - return new_func - - return actual_decorator - - -class Progbar(object): - """Displays a progress bar. 
- - # Arguments - target: Total number of steps expected, None if unknown. - width: Progress bar width on screen. - verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) - stateful_metrics: Iterable of string names of metrics that - should *not* be averaged over time. Metrics in this list - will be displayed as-is. All others will be averaged - by the progbar before display. - interval: Minimum visual progress update interval (in seconds). - """ - - def __init__( - self, target, width=30, verbose=1, interval=0.05, stateful_metrics=None - ): - self.target = target - self.width = width - self.verbose = verbose - self.interval = interval - if stateful_metrics: - self.stateful_metrics = set(stateful_metrics) - else: - self.stateful_metrics = set() - - self._dynamic_display = ( - hasattr(sys.stdout, "isatty") and sys.stdout.isatty() - ) or "ipykernel" in sys.modules - self._total_width = 0 - self._seen_so_far = 0 - self._values = collections.OrderedDict() - self._start = time.time() - self._last_update = 0 - - def update(self, current, values=None): - """Updates the progress bar. - - # Arguments - current: Index of current step. - values: List of tuples: - `(name, value_for_last_step)`. - If `name` is in `stateful_metrics`, - `value_for_last_step` will be displayed as-is. - Else, an average of the metric over time will be displayed. - """ - values = values or [] - for k, v in values: - if k not in self.stateful_metrics: - if k not in self._values: - self._values[k] = [ - v * (current - self._seen_so_far), - current - self._seen_so_far, - ] - else: - self._values[k][0] += v * (current - self._seen_so_far) - self._values[k][1] += current - self._seen_so_far - else: - # Stateful metrics output a numeric value. This representation - # means "take an average from a single value" but keeps the - # numeric formatting. - self._values[k] = [v, 1] - self._seen_so_far = current - - now = time.time() - info = " - %.0fs" % (now - self._start) - if self.verbose == 1: - if ( - now - self._last_update < self.interval - and self.target is not None - and current < self.target - ): - return - - prev_total_width = self._total_width - if self._dynamic_display: - sys.stdout.write("\b" * prev_total_width) - sys.stdout.write("\r") - else: - sys.stdout.write("\n") - - if self.target is not None: - numdigits = int(np.floor(np.log10(self.target))) + 1 - barstr = "%%%dd/%d [" % (numdigits, self.target) - bar = barstr % current - prog = float(current) / self.target - prog_width = int(self.width * prog) - if prog_width > 0: - bar += "=" * (prog_width - 1) - if current < self.target: - bar += ">" - else: - bar += "=" - bar += "." 
* (self.width - prog_width) - bar += "]" - else: - bar = "%7d/Unknown" % current - - self._total_width = len(bar) - sys.stdout.write(bar) - - if current: - time_per_unit = (now - self._start) / current - else: - time_per_unit = 0 - if self.target is not None and current < self.target: - eta = time_per_unit * (self.target - current) - if eta > 3600: - eta_format = "%d:%02d:%02d" % ( - eta // 3600, - (eta % 3600) // 60, - eta % 60, - ) - elif eta > 60: - eta_format = "%d:%02d" % (eta // 60, eta % 60) - else: - eta_format = "%ds" % eta - - info = " - ETA: %s" % eta_format - else: - if time_per_unit >= 1: - info += " %.0fs/step" % time_per_unit - elif time_per_unit >= 1e-3: - info += " %.0fms/step" % (time_per_unit * 1e3) - else: - info += " %.0fus/step" % (time_per_unit * 1e6) - - for k in self._values: - info += " - %s:" % k - if isinstance(self._values[k], list): - avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) - if abs(avg) > 1e-3: - info += " %.4f" % avg - else: - info += " %.4e" % avg - else: - info += " %s" % self._values[k] - - self._total_width += len(info) - if prev_total_width > self._total_width: - info += " " * (prev_total_width - self._total_width) - - if self.target is not None and current >= self.target: - info += "\n" - - sys.stdout.write(info) - sys.stdout.flush() - - elif self.verbose == 2: - if self.target is None or current >= self.target: - for k in self._values: - info += " - %s:" % k - avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) - if avg > 1e-3: - info += " %.4f" % avg - else: - info += " %.4e" % avg - info += "\n" - - sys.stdout.write(info) - sys.stdout.flush() - - self._last_update = now - - def add(self, n, values=None): - self.update(self._seen_so_far + n, values) - - -class Callback(object): - """Callback基类""" - - def __init__(self): - pass - - def on_train_begin(self, logs=None): - pass - - def on_train_end(self, logs=None): - pass - - def on_epoch_begin(self, global_step, epoch, logs=None): - pass - - def on_epoch_end(self, global_step, epoch, logs=None): - pass - - def on_batch_begin(self, global_step, batch, logs=None): - pass - - def on_batch_end(self, global_step, batch, logs=None): - pass - - def on_dataloader_end(self, logs=None): - pass - - -class ProgbarLogger(Callback): - """Callback that prints metrics to stdout. - - # Arguments - count_mode: One of "steps" or "samples". - Whether the progress bar should - count samples seen or steps (batches) seen. - stateful_metrics: Iterable of string names of metrics that - should *not* be averaged over an epoch. - Metrics in this list will be logged as-is. - All others will be averaged over time (e.g. loss, etc). - - # Raises - ValueError: In case of invalid `count_mode`. 
- """ - - def __init__(self, epochs, steps, metrics, stateful_metrics=None, verbose=1): - super(ProgbarLogger, self).__init__() - if stateful_metrics: - self.stateful_metrics = set(stateful_metrics) - else: - self.stateful_metrics = set() - self.params = { - "epochs": epochs, - "steps": steps, - "verbose": verbose, - "metrics": metrics, - } - self.verbose = verbose - self.epochs = epochs - - def add_metrics(self, metrics, add_position=None): - if add_position is None: - add_position = len(self.params["metrics"]) - if isinstance(metrics, str): - metrics = [metrics] - - add_metrics = [] - for metric in metrics: - if metric not in self.params["metrics"]: - add_metrics.append(metric) - self.params["metrics"] = ( - self.params["metrics"][:add_position] - + add_metrics - + self.params["metrics"][add_position:] - ) - - def on_train_begin(self, logs=None): - if self.verbose: - print("Start Training".center(40, "=")) - - def on_epoch_begin(self, global_step=None, epoch=None, logs=None): - if self.verbose: - print("Epoch %d/%d" % (epoch + 1, self.epochs)) - self.target = self.params["steps"] - self.progbar = Progbar( - target=self.target, - verbose=self.verbose, - stateful_metrics=self.stateful_metrics, - ) - self.seen = 0 - - def on_batch_begin(self, global_step=None, batch=None, logs=None): - if self.seen < self.target: - self.log_values = [] - - def on_batch_end(self, global_step=None, batch=None, logs=None): - logs = logs or {} - self.seen += 1 - for k in self.params["metrics"]: - if k in logs: - self.log_values.append((k, logs[k])) - - # Skip progbar update for the last batch; - # will be handled by on_epoch_end. - if self.verbose and self.seen < self.target: - self.progbar.update(self.seen, self.log_values) - - def on_epoch_end(self, global_step=None, epoch=None, logs=None): - logs = logs or {} - for k in self.params["metrics"]: - if k in logs: - self.log_values.append((k, logs[k])) - if self.verbose: - self.progbar.update(self.seen, self.log_values) - - def on_train_end(self, logs=None): - if self.verbose: - print("Finish Training".center(40, "=")) - - -class EarlyStopping(Callback): - def __init__( - self, - monitor="loss", - min_delta=0, - patience=0, - verbose=0, - mode="auto", - baseline=None, - ): - super(EarlyStopping, self).__init__() - - self.monitor = monitor - self.baseline = baseline - self.patience = patience - self.verbose = verbose - self.min_delta = min_delta - self.wait = 0 - self.stopped_epoch = 0 - - if mode not in ["auto", "min", "max"]: - warnings.warn( - "EarlyStopping mode %s is unknown, fallback to auto mode." 
% mode, - RuntimeWarning, - ) - mode = "auto" - - if mode == "min": - self.monitor_op = np.less - elif mode == "max": - self.monitor_op = np.greater - else: - self.monitor_op = np.greater if "acc" in self.monitor else np.less - self.min_delta = ( - self.min_delta if self.monitor_op == np.greater else -self.min_delta - ) - - def on_train_begin(self, logs=None): - # Allow instances to be re-used - self.wait = 0 - self.stopped_epoch = 0 - if self.baseline is not None: - self.best = self.baseline - else: - self.best = np.Inf if self.monitor_op == np.less else -np.Inf - - def on_epoch_end(self, steps, epoch, logs=None): - current = self.get_monitor_value(logs) - if current is None: - return - - if self.monitor_op(current - self.min_delta, self.best): - self.best = current - self.wait = 0 - else: - self.wait += 1 - if self.wait >= self.patience: - self.stopped_epoch = epoch - - def on_train_end(self, logs=None): - if self.stopped_epoch > 0 and self.verbose > 0: - print(f"Epoch {self.stopped_epoch+1}: early stopping\n") - - def get_monitor_value(self, logs): - monitor_value = logs.get(self.monitor) - if monitor_value is None: - warnings.warn( - "Early stopping conditioned on metric `%s` " - "which is not available. Available metrics are: %s" - % (self.monitor, ",".join(list(logs.keys()))), - RuntimeWarning, - ) - return monitor_value - - -def metric_mapping(metric, y_pred, y_true): - if metric == "accuracy": - if isinstance(y_pred, (list, tuple)): - y_pred = y_pred[0] - y_pred = torch.argmax(y_pred, dim=-1) - acc = torch.sum(y_pred.eq(y_true)).item() / y_true.size(0) - return acc - return None - - -def softmax(x, axis=-1): - x = x - x.max(axis=axis, keepdims=True) - x = np.exp(x) - return x / x.sum(axis=axis, keepdims=True) - - -class AutoRegressiveDecoder(object): - def __init__(self, start_id, end_id, maxlen, minlen=1, device="cpu"): - self.start_id = start_id - self.end_id = end_id - self.maxlen = maxlen - self.minlen = minlen - self.models = {} - self.device = device - if start_id is None: - self.first_output_ids = torch.empty((1, 0), dtype=int, device=device) - else: - self.first_output_ids = torch.tensor([[self.start_id]], device=device) - - @staticmethod - def wraps(default_rtype="probas", use_states=False): - def actual_decorator(predict): - def new_predict( - self, inputs, output_ids, states, temperature=1, rtype=default_rtype - ): - assert rtype in ["probas", "logits"] - prediction = predict(self, inputs, output_ids, states) - - if not use_states: - prediction = (prediction, None) - - if default_rtype == "logits": - prediction = ( - nn.Softmax(dim=-1)(prediction[0] / temperature), - prediction[1], - ) - elif temperature != 1: - probas = torch.power(prediction[0], 1.0 / temperature) - probas = probas / probas.sum(axis=-1, keepdims=True) - prediction = (probas, prediction[1]) - - if rtype == "probas": - return prediction - else: - return torch.log(prediction[0] + 1e-12), prediction[1] - - return new_predict - - return actual_decorator - - def predict(self, inputs, output_ids, states=None): - raise NotImplementedError - - def beam_search( - self, inputs_raw, topk, states=None, temperature=1, min_ends=1, add_btz_dim=True - ): - inputs = [] - for i in inputs_raw: - if isinstance(i, torch.torch.Tensor): - pass - elif isinstance(i, (list, tuple, np.ndarray)) and add_btz_dim: - i = torch.tensor([i], device=self.device) - elif isinstance(i, (list, tuple, np.ndarray)) and not add_btz_dim: - i = torch.tensor(i, device=self.device) - else: - raise ValueError( - "Beam search inputs ele only support 
tensor、array、list、tuple" - ) - inputs.append(i) - - output_ids, output_scores = self.first_output_ids, torch.zeros( - 1, device=self.device - ) - for step in range(self.maxlen): - scores, states = self.predict( - inputs, output_ids, states, temperature, "logits" - ) - if step == 0: - inputs = [i.repeat([topk] + [1] * (len(i.shape) - 1)) for i in inputs] - scores = output_scores.reshape((-1, 1)) + scores - indices = scores.flatten().argsort(dim=-1, descending=True)[:topk] - indices_1 = torch.div( - indices, scores.shape[1], rounding_mode="trunc" - ) - indices_2 = (indices % scores.shape[1]).reshape((-1, 1)) - output_ids = torch.cat([output_ids[indices_1], indices_2], 1) - output_scores = take_along_dim(scores, indices, dim=None) - is_end = output_ids[:, -1] == self.end_id - end_counts = (output_ids == self.end_id).sum(1) - if output_ids.shape[1] >= self.minlen: - best = output_scores.argmax() - if is_end[best] and end_counts[best] >= min_ends: - return output_ids[best] - else: - flag = ~is_end | (end_counts < min_ends) - if not flag.all(): - inputs = [i[flag] for i in inputs] - output_ids = output_ids[flag] - output_scores = output_scores[flag] - end_counts = end_counts[flag] - topk = flag.sum() - return output_ids[output_scores.argmax()] - - def random_sample( - self, inputs, n, topk=None, topp=None, states=None, temperature=1, min_ends=1 - ): - inputs = [torch.tensor([i], device=self.device) for i in inputs] - output_ids = self.first_output_ids - results = [] - for step in range(self.maxlen): - probas, states = self.predict( - inputs, output_ids, states, temperature, "probas" - ) - probas /= probas.sum(dim=-1, keepdims=True) - if step == 0: - probas = probas.repeat([n] + [1] * (len(probas.shape) - 1)) - inputs = [i.repeat([n] + [1] * (len(i.shape) - 1)) for i in inputs] - output_ids = output_ids.repeat([n] + [1] * (len(output_ids.shape) - 1)) - if topk is not None: - k_indices = probas.argsort(dim=-1, descending=True)[:, :topk] - probas = take_along_dim(probas, k_indices, dim=1) - probas /= probas.sum(dim=1, keepdims=True) - if topp is not None: - p_indices = probas.argsort(dim=-1, descending=True) - probas = take_along_dim(probas, p_indices, dim=-1) - cumsum_probas = torch.cumsum(probas, dim=-1) - flag = torch.roll(cumsum_probas >= topp, 1, dims=1) - flag[:, 0] = False - probas[flag] = 0 - probas /= probas.sum(dim=1, keepdims=True) - - sample_func = lambda p: torch.multinomial(p, 1) - sample_ids = torch.stack([sample_func(p) for p in probas]) - sample_ids = sample_ids.reshape((-1, 1)) - if topp is not None: - sample_ids = take_along_dim(p_indices, sample_ids, dim=1) - if topk is not None: - sample_ids = take_along_dim(k_indices, sample_ids, dim=1) - output_ids = torch.cat([output_ids, sample_ids], 1) - is_end = output_ids[:, -1] == self.end_id - end_counts = (output_ids == self.end_id).sum(1) - if output_ids.shape[1] >= self.minlen: - flag = is_end & (end_counts >= min_ends) - if flag.any(): - for ids in output_ids[flag]: - results.append(ids) - flag = flag == False - inputs = [i[flag] for i in inputs] - output_ids = output_ids[flag] - end_counts = end_counts[flag] - if len(output_ids) == 0: - break - for ids in output_ids: - results.append(ids) - return results - - -def search_layer(model, layer_name, retrun_first=True): - return_list = [] - for name, param in model.named_parameters(): - if param.requires_grad and layer_name in name: - return_list.append(param) - if len(return_list) == 0: - return None - if retrun_first: - return return_list[0] - else: - return return_list - - -class 
ListDataset(Dataset): - def __init__(self, file_path=None, data=None, **kwargs): - self.kwargs = kwargs - if isinstance(file_path, (str, list)): - self.data = self.load_data(file_path) - elif isinstance(data, list): - self.data = data - else: - raise ValueError( - "The input args shall be str format file_path / list format dataset" - ) - - def __len__(self): - return len(self.data) - - def __getitem__(self, index): - return self.data[index] - - @staticmethod - def load_data(file_path): - return file_path - - -class IterDataset(IterableDataset): - """流式读取文件""" - - def __init__(self, file_path=None, **kwargs): - self.kwargs = kwargs - if isinstance(file_path, (str, list)): - self.file_path = file_path - else: - raise ValueError( - "The input args shall be str format file_path / list format dataset" - ) - - def __iter__(self): - return self.load_data(self.file_path) - - @staticmethod - def load_data(file_path): - return file_path - - -# sinusoid编码 -def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): - """Returns: [seq_len, d_hid]""" - position = torch.arange(0, n_position, dtype=torch.float).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, d_hid, 2).float() * (-math.log(10000.0) / d_hid) - ) - embeddings_table = torch.zeros(n_position, d_hid) - embeddings_table[:, 0::2] = torch.sin(position * div_term) - embeddings_table[:, 1::2] = torch.cos(position * div_term) - return embeddings_table - -def cal_ts_num(tensor_shape): - cal_num = 0 - for obj in gc.get_objects(): - try: - if torch.is_tensor( - obj - ): # or (hasattr(obj, 'data') and torch.is_tensor(obj.data)): - tensor = obj - else: - continue - if tensor.is_cuda and tensor.size() == tensor_shape: - print(tensor.shape) - cal_num += 1 - except Exception as e: - print("A trivial exception occured: {}".format(e)) - print(cal_num) - - -def get_kw(cls, kwargs): - kwargs_new = {} - for k in kwargs: - if k not in set(inspect.getargspec(cls)[0]): - kwargs_new[k] = kwargs[k] - return kwargs_new - - -class FGM: - def __init__(self, model): - self.model = model - self.backup = {} - - def attack(self, epsilon=1.0, emb_name="word_embeddings", **kwargs): - for name, param in self.model.named_parameters(): - if param.requires_grad and emb_name in name: - self.backup[name] = param.data.clone() - norm = torch.norm(param.grad) - if norm != 0 and not torch.isnan(norm): - r_at = epsilon * param.grad / norm - param.data.add_(r_at) - - def restore(self, emb_name="emb", **kwargs): - for name, param in self.model.named_parameters(): - if param.requires_grad and emb_name in name: - assert name in self.backup - param.data = self.backup[name] - self.backup = {} - -class PGD: - def __init__(self, model): - self.model = model - self.emb_backup = {} - self.grad_backup = {} - - def attack( - self, - epsilon=1.0, - alpha=0.3, - emb_name="word_embeddings", - is_first_attack=False, - **kwargs, - ): - for name, param in self.model.named_parameters(): - if param.requires_grad and emb_name in name: - if is_first_attack: - self.emb_backup[name] = param.data.clone() - norm = torch.norm(param.grad) - if norm != 0 and not torch.isnan(norm): - r_at = alpha * param.grad / norm - param.data.add_(r_at) - param.data = self.project(name, param.data, epsilon) - - def restore(self, emb_name="emb", **kwargs): - for name, param in self.model.named_parameters(): - if param.requires_grad and emb_name in name: - assert name in self.emb_backup - param.data = self.emb_backup[name] - self.emb_backup = {} - - def project(self, param_name, param_data, epsilon): - r = 
param_data - self.emb_backup[param_name] - if torch.norm(r) > epsilon: - r = epsilon * r / torch.norm(r) - return self.emb_backup[param_name] + r - - def backup_grad(self): - for name, param in self.model.named_parameters(): - if param.requires_grad and (param.grad is not None): - self.grad_backup[name] = param.grad.clone() - - def restore_grad(self): - for name, param in self.model.named_parameters(): - if param.requires_grad and (param.grad is not None): - param.grad = self.grad_backup[name] - - -class VAT: - def __init__( - self, - model, - emb_name="word_embeddings", - noise_var=1e-5, - noise_gamma=1e-6, - adv_step_size=1e-3, - adv_alpha=1, - norm_type="l2", - **kwargs, - ): - self.model = model - self.noise_var = noise_var - self.noise_gamma = noise_gamma - self.adv_step_size = adv_step_size - self.adv_alpha = adv_alpha - self.norm_type = norm_type - self.embed = None - for (name, module) in self.model.named_modules(): - if emb_name in name: - module.register_forward_hook(hook=self.hook) - - def hook(self, module, fea_in, fea_out): - self.embed = fea_out - return None - - def forward_(self, train_X, new_embed): - if isinstance(train_X, (tuple, list)): - new_train_X = [new_embed] + train_X[1:] - adv_output = ( - self.model.forward(*new_train_X) - if self.model.forward.__code__.co_argcount >= 3 - else self.model.forward(new_train_X) - ) - elif isinstance(train_X, torch.Tensor): - adv_output = self.model.forward(new_embed) - return adv_output - - def virtual_adversarial_training(self, train_X, logits): - noise = self.embed.data.new(self.embed.size()).normal_(0, 1) * self.noise_var - noise.requires_grad_() - # x + r - new_embed = self.embed.data.detach() + noise - adv_output = self.forward_(train_X, new_embed) # forward第一次 - adv_logits = ( - adv_output[0] if isinstance(adv_output, (list, tuple)) else adv_output - ) - adv_loss = self.kl(adv_logits, logits.detach(), reduction="batchmean") - (delta_grad,) = torch.autograd.grad(adv_loss, noise, only_inputs=True) - norm = delta_grad.norm() - if torch.isnan(norm) or torch.isinf(norm): - return None - # inner sum - noise = noise + delta_grad * self.adv_step_size - # projection - noise = self.adv_project(noise, norm_type=self.norm_type, eps=self.noise_gamma) - new_embed = self.embed.data.detach() + noise - new_embed = new_embed.detach() - adv_output = self.forward_(train_X, new_embed) - adv_logits = ( - adv_output[0] if isinstance(adv_output, (list, tuple)) else adv_output - ) - adv_loss_f = self.kl(adv_logits, logits.detach()) - adv_loss_b = self.kl(logits, adv_logits.detach()) - adv_loss = (adv_loss_f + adv_loss_b) * self.adv_alpha - return adv_loss - - @staticmethod - def kl(inputs, targets, reduction="sum"): - loss = F.kl_div( - F.log_softmax(inputs, dim=-1), - F.softmax(targets, dim=-1), - reduction=reduction, - ) - return loss - - @staticmethod - def adv_project(grad, norm_type="inf", eps=1e-6): - if norm_type == "l2": - direction = grad / (torch.norm(grad, dim=-1, keepdim=True) + eps) - elif norm_type == "l1": - direction = grad.sign() - else: - direction = grad / (grad.abs().max(-1, keepdim=True)[0] + eps) - return direction - - -class WebServing(object): - def __init__(self, host="0.0.0.0", port=8000, server="paste"): - - import bottle - - self.host = host - self.port = port - self.server = server - self.bottle = bottle - - def wraps(self, func, arguments, method="GET"): - def new_func(): - outputs = {"code": 0, "desc": "succeeded", "data": {}} - kwargs = {} - for key, value in arguments.items(): - if method == "GET": - result = 
self.bottle.request.GET.getunicode(key) - else: - result = self.bottle.request.POST.getunicode(key) - if result is None: - if value[1]: - outputs["code"] = 1 - outputs["desc"] = 'lack of "%s" argument' % key - return json.dumps(outputs, ensure_ascii=False) - else: - if value[0] is not None: - result = value[0](result) - kwargs[key] = result - try: - outputs["data"] = func(**kwargs) - except Exception as e: - outputs["code"] = 2 - outputs["desc"] = str(e) - return json.dumps(outputs, ensure_ascii=False) - - return new_func - - def route(self, path, func, arguments, method="GET"): - func = self.wraps(func, arguments, method) - self.bottle.route(path, method=method)(func) - - def start(self): - self.bottle.run(host=self.host, port=self.port, server=self.server) - - -def get_pool_emb( - hidden_state=None, - pooler=None, - attention_mask=None, - pool_strategy="cls", - custom_layer=None, -): - if pool_strategy == "pooler": - return pooler - elif pool_strategy == "cls": - if isinstance(hidden_state, (list, tuple)): - hidden_state = hidden_state[-1] - assert isinstance( - hidden_state, torch.Tensor - ), f"{pool_strategy} strategy request tensor hidden_state" - return hidden_state[:, 0] - elif pool_strategy in {"last-avg", "mean"}: - if isinstance(hidden_state, (list, tuple)): - hidden_state = hidden_state[-1] - assert isinstance( - hidden_state, torch.Tensor - ), f"{pool_strategy} pooling strategy request tensor hidden_state" - hid = torch.sum(hidden_state * attention_mask[:, :, None], dim=1) - attention_mask = torch.sum(attention_mask, dim=1)[:, None] - return hid / attention_mask - elif pool_strategy in {"last-max", "max"}: - if isinstance(hidden_state, (list, tuple)): - hidden_state = hidden_state[-1] - assert isinstance( - hidden_state, torch.Tensor - ), f"{pool_strategy} pooling strategy request tensor hidden_state" - hid = hidden_state * attention_mask[:, :, None] - return torch.max(hid, dim=1) - elif pool_strategy == "first-last-avg": - assert isinstance( - hidden_state, list - ), f"{pool_strategy} pooling strategy request list hidden_state" - hid = torch.sum(hidden_state[1] * attention_mask[:, :, None], dim=1) - hid += torch.sum(hidden_state[-1] * attention_mask[:, :, None], dim=1) - attention_mask = torch.sum(attention_mask, dim=1)[:, None] - return hid / (2 * attention_mask) - elif pool_strategy == "custom": - assert isinstance( - hidden_state, list - ), f"{pool_strategy} pooling strategy request list hidden_state" - assert isinstance( - custom_layer, (int, list, tuple) - ), f"{pool_strategy} pooling strategy request int/list/tuple custom_layer" - custom_layer = [custom_layer] if isinstance(custom_layer, int) else custom_layer - hid = 0 - for i, layer in enumerate(custom_layer, start=1): - hid += torch.sum(hidden_state[layer] * attention_mask[:, :, None], dim=1) - attention_mask = torch.sum(attention_mask, dim=1)[:, None] - return hid / (i * attention_mask) - else: - raise ValueError("pool_strategy illegal") - - -def seed_everything(seed=None): - max_seed_value = np.iinfo(np.uint32).max - min_seed_value = np.iinfo(np.uint32).min - - if (seed is None) or not (min_seed_value <= seed <= max_seed_value): - random.randint(np.iinfo(np.uint32).min, np.iinfo(np.uint32).max) - os.environ["PYTHONHASHSEED"] = str(seed) - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - return seed +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import collections +import gc +import inspect +import json +import math +import os +import random +import re +import sys +import time +import unicodedata +import warnings + +import numpy as np +import six +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import Dataset, IterableDataset + +is_py2 = six.PY2 + +if not is_py2: + basestring = str + + +def take_along_dim(input_tensor, indices, dim=None): + if torch.__version__ >= "1.9.0": + return torch.take_along_dim(input_tensor, indices, dim) + else: + if dim is None: + res = input_tensor.flatten()[indices] + else: + res = np.take_along_axis( + input_tensor.cpu().numpy(), indices.cpu().numpy(), axis=dim + ) + res = torch.from_numpy(res).to(input_tensor.device) + return res + + +def is_string(s): + return isinstance(s, basestring) + + +def truncate_sequences(maxlen, indices, *sequences): + sequences = [s for s in sequences if s] + if not isinstance(indices, (list, tuple)): + indices = [indices] * len(sequences) + + while True: + lengths = [len(s) for s in sequences] + if sum(lengths) > maxlen: + i = np.argmax(lengths) + sequences[i].pop(indices[i]) + else: + return sequences + + +def text_segmentate(text, maxlen, seps="\n", strips=None, truncate=True): + text = text.strip().strip(strips) + if seps and len(text) > maxlen: + pieces = text.split(seps[0]) + text, texts = "", [] + for i, p in enumerate(pieces): + if text and p and len(text) + len(p) > maxlen - 1: + texts.extend(text_segmentate(text, maxlen, seps[1:], strips, truncate)) + text = "" + if i + 1 == len(pieces): + text = text + p + else: + text = text + p + seps[0] + if text: + texts.extend(text_segmentate(text, maxlen, seps[1:], strips, truncate)) + return texts + elif truncate and (not seps) and (len(text) > maxlen): + return [ + text[i * maxlen : (i + 1) * maxlen] + for i in range(0, int(np.ceil(len(text) / maxlen))) + ] + else: + return [text] + + +def merge_segmentate(sequences, maxlen, sep=""): + sequences_new = [] + text = "" + for t in sequences: + if text and len(text + sep + t) <= maxlen: + text = text + sep + t + elif text: + sequences_new.append(text) + text = t + elif len(t) < maxlen: + text = t + else: + sequences_new.append(t) + text = "" + if text: + sequences_new.append(text) + return sequences_new + + +def text_augmentation( + texts, + noise_dict=None, + noise_len=0, + noise_p=0.0, + skip_words=None, + strategy="random", + allow_dup=True, +): + def insert(text, insert_idx, noise_dict): + text = list(text) + for i in insert_idx: + text[i] = text[i] + random.choice(noise_dict) + return "".join(text) + + def delete(text, delete_idx): + text = list(text) + for i in delete_idx: + text[i] = "" + return "".join(text) + + def replace(text, replace_idx, noise_dict): + text = list(text) + for i in replace_idx: + text[i] = random.choice(noise_dict) + return "".join(text) + + def search(pattern, sequence, keep_last=True): + n = len(pattern) + 
pattern_idx_set = set() + for i in range(len(sequence)): + if sequence[i : i + n] == pattern: + pattern_idx_set = ( + pattern_idx_set.union(set(range(i, i + n))) + if keep_last + else pattern_idx_set.union(set(range(i, i + n - 1))) + ) + return pattern_idx_set + + if (noise_len == 0) and (noise_p == 0): + return texts + + assert strategy in { + "insert", + "delete", + "replace", + "random", + }, "EDA strategy only support insert, delete, replace, random" + + if isinstance(texts, str): + texts = [texts] + + if skip_words is None: + skip_words = [] + elif isinstance(skip_words, str): + skip_words = [skip_words] + + for id, text in enumerate(texts): + sel_len = noise_len if noise_len > 0 else int(len(text) * noise_p) + skip_idx = set() + for item in skip_words: + skip_idx = skip_idx.union(search(item, text, strategy != "insert")) + + sel_idxs = [i for i in range(len(text)) if i not in skip_idx] + sel_len = ( + sel_len if allow_dup else min(sel_len, len(sel_idxs)) + ) + if (sel_len == 0) or (len(sel_idxs) == 0): + continue + sel_idx = np.random.choice(sel_idxs, sel_len, replace=allow_dup) + if strategy == "insert": + texts[id] = insert(text, sel_idx, noise_dict) + elif strategy == "delete": + texts[id] = delete(text, sel_idx) + elif strategy == "replace": + texts[id] = replace(text, sel_idx, noise_dict) + elif strategy == "random": + if random.random() < 0.333: + skip_idx = set() + for item in skip_words: + skip_idx = skip_idx.union(search(item, text, keep_last=False)) + texts[id] = insert(text, sel_idx, noise_dict) + elif random.random() < 0.667: + texts[id] = delete(text, sel_idx) + else: + texts[id] = replace(text, sel_idx, noise_dict) + return texts if len(texts) > 1 else texts[0] + + +def lowercase_and_normalize(text, never_split=()): + if is_py2: + text = unicode(text) + + # convert non-special tokens to lowercase + escaped_special_toks = [re.escape(s_tok) for s_tok in never_split] + pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)" + text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text) + + text = unicodedata.normalize("NFD", text) + text = "".join([ch for ch in text if unicodedata.category(ch) != "Mn"]) + return text + + +def sequence_padding(inputs, length=None, value=0, seq_dims=1, mode="post"): + if isinstance(inputs[0], (np.ndarray, list)): + if length is None: + length = np.max([np.shape(x)[:seq_dims] for x in inputs], axis=0) + elif not hasattr(length, "__getitem__"): + length = [length] + + slices = [np.s_[: length[i]] for i in range(seq_dims)] + slices = tuple(slices) if len(slices) > 1 else slices[0] + pad_width = [(0, 0) for _ in np.shape(inputs[0])] + + outputs = [] + for x in inputs: + x = x[slices] + for i in range(seq_dims): + if mode == "post": + pad_width[i] = (0, length[i] - np.shape(x)[i]) + elif mode == "pre": + pad_width[i] = (length[i] - np.shape(x)[i], 0) + else: + raise ValueError('"mode" argument must be "post" or "pre".') + x = np.pad(x, pad_width, "constant", constant_values=value) + outputs.append(x) + + return np.array(outputs) + + elif isinstance(inputs[0], torch.Tensor): + assert ( + mode == "post" + ), '"mode" argument must be "post" when element is torch.Tensor' + if length is not None: + inputs = [i[:length] for i in inputs] + return pad_sequence(inputs, padding_value=value, batch_first=True) + else: + raise ValueError('"input" argument must be tensor/list/ndarray.') + + +def insert_arguments(**arguments): + def actual_decorator(func): + def new_func(self, *args, **kwargs): + for k, v in arguments.items(): + if 
k in kwargs: + v = kwargs.pop(k) + setattr(self, k, v) + return func(self, *args, **kwargs) + + return new_func + + return actual_decorator + + +def delete_arguments(*arguments): + def actual_decorator(func): + def new_func(self, *args, **kwargs): + for k in arguments: + if k in kwargs: + raise TypeError( + "%s got an unexpected keyword argument '%s'" + % (self.__class__.__name__, k) + ) + return func(self, *args, **kwargs) + + return new_func + + return actual_decorator + + +class Progbar(object): + """Displays a progress bar. + + # Arguments + target: Total number of steps expected, None if unknown. + width: Progress bar width on screen. + verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) + stateful_metrics: Iterable of string names of metrics that + should *not* be averaged over time. Metrics in this list + will be displayed as-is. All others will be averaged + by the progbar before display. + interval: Minimum visual progress update interval (in seconds). + """ + + def __init__( + self, target, width=30, verbose=1, interval=0.05, stateful_metrics=None + ): + self.target = target + self.width = width + self.verbose = verbose + self.interval = interval + if stateful_metrics: + self.stateful_metrics = set(stateful_metrics) + else: + self.stateful_metrics = set() + + self._dynamic_display = ( + hasattr(sys.stdout, "isatty") and sys.stdout.isatty() + ) or "ipykernel" in sys.modules + self._total_width = 0 + self._seen_so_far = 0 + self._values = collections.OrderedDict() + self._start = time.time() + self._last_update = 0 + + def update(self, current, values=None): + """Updates the progress bar. + + # Arguments + current: Index of current step. + values: List of tuples: + `(name, value_for_last_step)`. + If `name` is in `stateful_metrics`, + `value_for_last_step` will be displayed as-is. + Else, an average of the metric over time will be displayed. + """ + values = values or [] + for k, v in values: + if k not in self.stateful_metrics: + if k not in self._values: + self._values[k] = [ + v * (current - self._seen_so_far), + current - self._seen_so_far, + ] + else: + self._values[k][0] += v * (current - self._seen_so_far) + self._values[k][1] += current - self._seen_so_far + else: + # Stateful metrics output a numeric value. This representation + # means "take an average from a single value" but keeps the + # numeric formatting. + self._values[k] = [v, 1] + self._seen_so_far = current + + now = time.time() + info = " - %.0fs" % (now - self._start) + if self.verbose == 1: + if ( + now - self._last_update < self.interval + and self.target is not None + and current < self.target + ): + return + + prev_total_width = self._total_width + if self._dynamic_display: + sys.stdout.write("\b" * prev_total_width) + sys.stdout.write("\r") + else: + sys.stdout.write("\n") + + if self.target is not None: + numdigits = int(np.floor(np.log10(self.target))) + 1 + barstr = "%%%dd/%d [" % (numdigits, self.target) + bar = barstr % current + prog = float(current) / self.target + prog_width = int(self.width * prog) + if prog_width > 0: + bar += "=" * (prog_width - 1) + if current < self.target: + bar += ">" + else: + bar += "=" + bar += "." 
* (self.width - prog_width) + bar += "]" + else: + bar = "%7d/Unknown" % current + + self._total_width = len(bar) + sys.stdout.write(bar) + + if current: + time_per_unit = (now - self._start) / current + else: + time_per_unit = 0 + if self.target is not None and current < self.target: + eta = time_per_unit * (self.target - current) + if eta > 3600: + eta_format = "%d:%02d:%02d" % ( + eta // 3600, + (eta % 3600) // 60, + eta % 60, + ) + elif eta > 60: + eta_format = "%d:%02d" % (eta // 60, eta % 60) + else: + eta_format = "%ds" % eta + + info = " - ETA: %s" % eta_format + else: + if time_per_unit >= 1: + info += " %.0fs/step" % time_per_unit + elif time_per_unit >= 1e-3: + info += " %.0fms/step" % (time_per_unit * 1e3) + else: + info += " %.0fus/step" % (time_per_unit * 1e6) + + for k in self._values: + info += " - %s:" % k + if isinstance(self._values[k], list): + avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) + if abs(avg) > 1e-3: + info += " %.4f" % avg + else: + info += " %.4e" % avg + else: + info += " %s" % self._values[k] + + self._total_width += len(info) + if prev_total_width > self._total_width: + info += " " * (prev_total_width - self._total_width) + + if self.target is not None and current >= self.target: + info += "\n" + + sys.stdout.write(info) + sys.stdout.flush() + + elif self.verbose == 2: + if self.target is None or current >= self.target: + for k in self._values: + info += " - %s:" % k + avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) + if avg > 1e-3: + info += " %.4f" % avg + else: + info += " %.4e" % avg + info += "\n" + + sys.stdout.write(info) + sys.stdout.flush() + + self._last_update = now + + def add(self, n, values=None): + self.update(self._seen_so_far + n, values) + + +class Callback(object): + """Callback基类""" + + def __init__(self): + pass + + def on_train_begin(self, logs=None): + pass + + def on_train_end(self, logs=None): + pass + + def on_epoch_begin(self, global_step, epoch, logs=None): + pass + + def on_epoch_end(self, global_step, epoch, logs=None): + pass + + def on_batch_begin(self, global_step, batch, logs=None): + pass + + def on_batch_end(self, global_step, batch, logs=None): + pass + + def on_dataloader_end(self, logs=None): + pass + + +class ProgbarLogger(Callback): + """Callback that prints metrics to stdout. + + # Arguments + count_mode: One of "steps" or "samples". + Whether the progress bar should + count samples seen or steps (batches) seen. + stateful_metrics: Iterable of string names of metrics that + should *not* be averaged over an epoch. + Metrics in this list will be logged as-is. + All others will be averaged over time (e.g. loss, etc). + + # Raises + ValueError: In case of invalid `count_mode`. 
+ """ + + def __init__(self, epochs, steps, metrics, stateful_metrics=None, verbose=1): + super(ProgbarLogger, self).__init__() + if stateful_metrics: + self.stateful_metrics = set(stateful_metrics) + else: + self.stateful_metrics = set() + self.params = { + "epochs": epochs, + "steps": steps, + "verbose": verbose, + "metrics": metrics, + } + self.verbose = verbose + self.epochs = epochs + + def add_metrics(self, metrics, add_position=None): + if add_position is None: + add_position = len(self.params["metrics"]) + if isinstance(metrics, str): + metrics = [metrics] + + add_metrics = [] + for metric in metrics: + if metric not in self.params["metrics"]: + add_metrics.append(metric) + self.params["metrics"] = ( + self.params["metrics"][:add_position] + + add_metrics + + self.params["metrics"][add_position:] + ) + + def on_train_begin(self, logs=None): + if self.verbose: + print("Start Training".center(40, "=")) + + def on_epoch_begin(self, global_step=None, epoch=None, logs=None): + if self.verbose: + print("Epoch %d/%d" % (epoch + 1, self.epochs)) + self.target = self.params["steps"] + self.progbar = Progbar( + target=self.target, + verbose=self.verbose, + stateful_metrics=self.stateful_metrics, + ) + self.seen = 0 + + def on_batch_begin(self, global_step=None, batch=None, logs=None): + if self.seen < self.target: + self.log_values = [] + + def on_batch_end(self, global_step=None, batch=None, logs=None): + logs = logs or {} + self.seen += 1 + for k in self.params["metrics"]: + if k in logs: + self.log_values.append((k, logs[k])) + + # Skip progbar update for the last batch; + # will be handled by on_epoch_end. + if self.verbose and self.seen < self.target: + self.progbar.update(self.seen, self.log_values) + + def on_epoch_end(self, global_step=None, epoch=None, logs=None): + logs = logs or {} + for k in self.params["metrics"]: + if k in logs: + self.log_values.append((k, logs[k])) + if self.verbose: + self.progbar.update(self.seen, self.log_values) + + def on_train_end(self, logs=None): + if self.verbose: + print("Finish Training".center(40, "=")) + + +class EarlyStopping(Callback): + def __init__( + self, + monitor="loss", + min_delta=0, + patience=0, + verbose=0, + mode="auto", + baseline=None, + ): + super(EarlyStopping, self).__init__() + + self.monitor = monitor + self.baseline = baseline + self.patience = patience + self.verbose = verbose + self.min_delta = min_delta + self.wait = 0 + self.stopped_epoch = 0 + + if mode not in ["auto", "min", "max"]: + warnings.warn( + "EarlyStopping mode %s is unknown, fallback to auto mode." 
% mode, + RuntimeWarning, + ) + mode = "auto" + + if mode == "min": + self.monitor_op = np.less + elif mode == "max": + self.monitor_op = np.greater + else: + self.monitor_op = np.greater if "acc" in self.monitor else np.less + self.min_delta = ( + self.min_delta if self.monitor_op == np.greater else -self.min_delta + ) + + def on_train_begin(self, logs=None): + # Allow instances to be re-used + self.wait = 0 + self.stopped_epoch = 0 + if self.baseline is not None: + self.best = self.baseline + else: + self.best = np.Inf if self.monitor_op == np.less else -np.Inf + + def on_epoch_end(self, steps, epoch, logs=None): + current = self.get_monitor_value(logs) + if current is None: + return + + if self.monitor_op(current - self.min_delta, self.best): + self.best = current + self.wait = 0 + else: + self.wait += 1 + if self.wait >= self.patience: + self.stopped_epoch = epoch + + def on_train_end(self, logs=None): + if self.stopped_epoch > 0 and self.verbose > 0: + print(f"Epoch {self.stopped_epoch+1}: early stopping\n") + + def get_monitor_value(self, logs): + monitor_value = logs.get(self.monitor) + if monitor_value is None: + warnings.warn( + "Early stopping conditioned on metric `%s` " + "which is not available. Available metrics are: %s" + % (self.monitor, ",".join(list(logs.keys()))), + RuntimeWarning, + ) + return monitor_value + + +def metric_mapping(metric, y_pred, y_true): + if metric == "accuracy": + if isinstance(y_pred, (list, tuple)): + y_pred = y_pred[0] + y_pred = torch.argmax(y_pred, dim=-1) + acc = torch.sum(y_pred.eq(y_true)).item() / y_true.size(0) + return acc + return None + + +def softmax(x, axis=-1): + x = x - x.max(axis=axis, keepdims=True) + x = np.exp(x) + return x / x.sum(axis=axis, keepdims=True) + + +class AutoRegressiveDecoder(object): + def __init__(self, start_id, end_id, maxlen, minlen=1, device="cpu"): + self.start_id = start_id + self.end_id = end_id + self.maxlen = maxlen + self.minlen = minlen + self.models = {} + self.device = device + if start_id is None: + self.first_output_ids = torch.empty((1, 0), dtype=int, device=device) + else: + self.first_output_ids = torch.tensor([[self.start_id]], device=device) + + @staticmethod + def wraps(default_rtype="probas", use_states=False): + def actual_decorator(predict): + def new_predict( + self, inputs, output_ids, states, temperature=1, rtype=default_rtype + ): + assert rtype in ["probas", "logits"] + prediction = predict(self, inputs, output_ids, states) + + if not use_states: + prediction = (prediction, None) + + if default_rtype == "logits": + prediction = ( + nn.Softmax(dim=-1)(prediction[0] / temperature), + prediction[1], + ) + elif temperature != 1: + probas = torch.power(prediction[0], 1.0 / temperature) + probas = probas / probas.sum(axis=-1, keepdims=True) + prediction = (probas, prediction[1]) + + if rtype == "probas": + return prediction + else: + return torch.log(prediction[0] + 1e-12), prediction[1] + + return new_predict + + return actual_decorator + + def predict(self, inputs, output_ids, states=None): + raise NotImplementedError + + def beam_search( + self, inputs_raw, topk, states=None, temperature=1, min_ends=1, add_btz_dim=True + ): + inputs = [] + for i in inputs_raw: + if isinstance(i, torch.torch.Tensor): + pass + elif isinstance(i, (list, tuple, np.ndarray)) and add_btz_dim: + i = torch.tensor([i], device=self.device) + elif isinstance(i, (list, tuple, np.ndarray)) and not add_btz_dim: + i = torch.tensor(i, device=self.device) + else: + raise ValueError( + "Beam search inputs ele only support 
tensor、array、list、tuple" + ) + inputs.append(i) + + output_ids, output_scores = self.first_output_ids, torch.zeros( + 1, device=self.device + ) + for step in range(self.maxlen): + scores, states = self.predict( + inputs, output_ids, states, temperature, "logits" + ) + if step == 0: + inputs = [i.repeat([topk] + [1] * (len(i.shape) - 1)) for i in inputs] + scores = output_scores.reshape((-1, 1)) + scores + indices = scores.flatten().argsort(dim=-1, descending=True)[:topk] + indices_1 = torch.div( + indices, scores.shape[1], rounding_mode="trunc" + ) + indices_2 = (indices % scores.shape[1]).reshape((-1, 1)) + output_ids = torch.cat([output_ids[indices_1], indices_2], 1) + output_scores = take_along_dim(scores, indices, dim=None) + is_end = output_ids[:, -1] == self.end_id + end_counts = (output_ids == self.end_id).sum(1) + if output_ids.shape[1] >= self.minlen: + best = output_scores.argmax() + if is_end[best] and end_counts[best] >= min_ends: + return output_ids[best] + else: + flag = ~is_end | (end_counts < min_ends) + if not flag.all(): + inputs = [i[flag] for i in inputs] + output_ids = output_ids[flag] + output_scores = output_scores[flag] + end_counts = end_counts[flag] + topk = flag.sum() + return output_ids[output_scores.argmax()] + + def random_sample( + self, inputs, n, topk=None, topp=None, states=None, temperature=1, min_ends=1 + ): + inputs = [torch.tensor([i], device=self.device) for i in inputs] + output_ids = self.first_output_ids + results = [] + for step in range(self.maxlen): + probas, states = self.predict( + inputs, output_ids, states, temperature, "probas" + ) + probas /= probas.sum(dim=-1, keepdims=True) + if step == 0: + probas = probas.repeat([n] + [1] * (len(probas.shape) - 1)) + inputs = [i.repeat([n] + [1] * (len(i.shape) - 1)) for i in inputs] + output_ids = output_ids.repeat([n] + [1] * (len(output_ids.shape) - 1)) + if topk is not None: + k_indices = probas.argsort(dim=-1, descending=True)[:, :topk] + probas = take_along_dim(probas, k_indices, dim=1) + probas /= probas.sum(dim=1, keepdims=True) + if topp is not None: + p_indices = probas.argsort(dim=-1, descending=True) + probas = take_along_dim(probas, p_indices, dim=-1) + cumsum_probas = torch.cumsum(probas, dim=-1) + flag = torch.roll(cumsum_probas >= topp, 1, dims=1) + flag[:, 0] = False + probas[flag] = 0 + probas /= probas.sum(dim=1, keepdims=True) + + sample_func = lambda p: torch.multinomial(p, 1) + sample_ids = torch.stack([sample_func(p) for p in probas]) + sample_ids = sample_ids.reshape((-1, 1)) + if topp is not None: + sample_ids = take_along_dim(p_indices, sample_ids, dim=1) + if topk is not None: + sample_ids = take_along_dim(k_indices, sample_ids, dim=1) + output_ids = torch.cat([output_ids, sample_ids], 1) + is_end = output_ids[:, -1] == self.end_id + end_counts = (output_ids == self.end_id).sum(1) + if output_ids.shape[1] >= self.minlen: + flag = is_end & (end_counts >= min_ends) + if flag.any(): + for ids in output_ids[flag]: + results.append(ids) + flag = flag == False + inputs = [i[flag] for i in inputs] + output_ids = output_ids[flag] + end_counts = end_counts[flag] + if len(output_ids) == 0: + break + for ids in output_ids: + results.append(ids) + return results + + +def search_layer(model, layer_name, retrun_first=True): + return_list = [] + for name, param in model.named_parameters(): + if param.requires_grad and layer_name in name: + return_list.append(param) + if len(return_list) == 0: + return None + if retrun_first: + return return_list[0] + else: + return return_list + + +class 
ListDataset(Dataset): + def __init__(self, file_path=None, data=None, **kwargs): + self.kwargs = kwargs + if isinstance(file_path, (str, list)): + self.data = self.load_data(file_path) + elif isinstance(data, list): + self.data = data + else: + raise ValueError( + "The input args shall be str format file_path / list format dataset" + ) + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + return self.data[index] + + @staticmethod + def load_data(file_path): + return file_path + + +class IterDataset(IterableDataset): + """流式读取文件""" + + def __init__(self, file_path=None, **kwargs): + self.kwargs = kwargs + if isinstance(file_path, (str, list)): + self.file_path = file_path + else: + raise ValueError( + "The input args shall be str format file_path / list format dataset" + ) + + def __iter__(self): + return self.load_data(self.file_path) + + @staticmethod + def load_data(file_path): + return file_path + + +# sinusoid编码 +def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): + """Returns: [seq_len, d_hid]""" + position = torch.arange(0, n_position, dtype=torch.float).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_hid, 2).float() * (-math.log(10000.0) / d_hid) + ) + embeddings_table = torch.zeros(n_position, d_hid) + embeddings_table[:, 0::2] = torch.sin(position * div_term) + embeddings_table[:, 1::2] = torch.cos(position * div_term) + return embeddings_table + +def cal_ts_num(tensor_shape): + cal_num = 0 + for obj in gc.get_objects(): + try: + if torch.is_tensor( + obj + ): # or (hasattr(obj, 'data') and torch.is_tensor(obj.data)): + tensor = obj + else: + continue + if tensor.is_cuda and tensor.size() == tensor_shape: + print(tensor.shape) + cal_num += 1 + except Exception as e: + print("A trivial exception occured: {}".format(e)) + print(cal_num) + + +def get_kw(cls, kwargs): + kwargs_new = {} + for k in kwargs: + if k not in set(inspect.getargspec(cls)[0]): + kwargs_new[k] = kwargs[k] + return kwargs_new + + +class FGM: + def __init__(self, model): + self.model = model + self.backup = {} + + def attack(self, epsilon=1.0, emb_name="word_embeddings", **kwargs): + for name, param in self.model.named_parameters(): + if param.requires_grad and emb_name in name: + self.backup[name] = param.data.clone() + norm = torch.norm(param.grad) + if norm != 0 and not torch.isnan(norm): + r_at = epsilon * param.grad / norm + param.data.add_(r_at) + + def restore(self, emb_name="emb", **kwargs): + for name, param in self.model.named_parameters(): + if param.requires_grad and emb_name in name: + assert name in self.backup + param.data = self.backup[name] + self.backup = {} + +class PGD: + def __init__(self, model): + self.model = model + self.emb_backup = {} + self.grad_backup = {} + + def attack( + self, + epsilon=1.0, + alpha=0.3, + emb_name="word_embeddings", + is_first_attack=False, + **kwargs, + ): + for name, param in self.model.named_parameters(): + if param.requires_grad and emb_name in name: + if is_first_attack: + self.emb_backup[name] = param.data.clone() + norm = torch.norm(param.grad) + if norm != 0 and not torch.isnan(norm): + r_at = alpha * param.grad / norm + param.data.add_(r_at) + param.data = self.project(name, param.data, epsilon) + + def restore(self, emb_name="emb", **kwargs): + for name, param in self.model.named_parameters(): + if param.requires_grad and emb_name in name: + assert name in self.emb_backup + param.data = self.emb_backup[name] + self.emb_backup = {} + + def project(self, param_name, param_data, epsilon): + r = 
param_data - self.emb_backup[param_name] + if torch.norm(r) > epsilon: + r = epsilon * r / torch.norm(r) + return self.emb_backup[param_name] + r + + def backup_grad(self): + for name, param in self.model.named_parameters(): + if param.requires_grad and (param.grad is not None): + self.grad_backup[name] = param.grad.clone() + + def restore_grad(self): + for name, param in self.model.named_parameters(): + if param.requires_grad and (param.grad is not None): + param.grad = self.grad_backup[name] + + +class VAT: + def __init__( + self, + model, + emb_name="word_embeddings", + noise_var=1e-5, + noise_gamma=1e-6, + adv_step_size=1e-3, + adv_alpha=1, + norm_type="l2", + **kwargs, + ): + self.model = model + self.noise_var = noise_var + self.noise_gamma = noise_gamma + self.adv_step_size = adv_step_size + self.adv_alpha = adv_alpha + self.norm_type = norm_type + self.embed = None + for (name, module) in self.model.named_modules(): + if emb_name in name: + module.register_forward_hook(hook=self.hook) + + def hook(self, module, fea_in, fea_out): + self.embed = fea_out + return None + + def forward_(self, train_X, new_embed): + if isinstance(train_X, (tuple, list)): + new_train_X = [new_embed] + train_X[1:] + adv_output = ( + self.model.forward(*new_train_X) + if self.model.forward.__code__.co_argcount >= 3 + else self.model.forward(new_train_X) + ) + elif isinstance(train_X, torch.Tensor): + adv_output = self.model.forward(new_embed) + return adv_output + + def virtual_adversarial_training(self, train_X, logits): + noise = self.embed.data.new(self.embed.size()).normal_(0, 1) * self.noise_var + noise.requires_grad_() + # x + r + new_embed = self.embed.data.detach() + noise + adv_output = self.forward_(train_X, new_embed) # forward第一次 + adv_logits = ( + adv_output[0] if isinstance(adv_output, (list, tuple)) else adv_output + ) + adv_loss = self.kl(adv_logits, logits.detach(), reduction="batchmean") + (delta_grad,) = torch.autograd.grad(adv_loss, noise, only_inputs=True) + norm = delta_grad.norm() + if torch.isnan(norm) or torch.isinf(norm): + return None + # inner sum + noise = noise + delta_grad * self.adv_step_size + # projection + noise = self.adv_project(noise, norm_type=self.norm_type, eps=self.noise_gamma) + new_embed = self.embed.data.detach() + noise + new_embed = new_embed.detach() + adv_output = self.forward_(train_X, new_embed) + adv_logits = ( + adv_output[0] if isinstance(adv_output, (list, tuple)) else adv_output + ) + adv_loss_f = self.kl(adv_logits, logits.detach()) + adv_loss_b = self.kl(logits, adv_logits.detach()) + adv_loss = (adv_loss_f + adv_loss_b) * self.adv_alpha + return adv_loss + + @staticmethod + def kl(inputs, targets, reduction="sum"): + loss = F.kl_div( + F.log_softmax(inputs, dim=-1), + F.softmax(targets, dim=-1), + reduction=reduction, + ) + return loss + + @staticmethod + def adv_project(grad, norm_type="inf", eps=1e-6): + if norm_type == "l2": + direction = grad / (torch.norm(grad, dim=-1, keepdim=True) + eps) + elif norm_type == "l1": + direction = grad.sign() + else: + direction = grad / (grad.abs().max(-1, keepdim=True)[0] + eps) + return direction + + +class WebServing(object): + def __init__(self, host="0.0.0.0", port=8000, server="paste"): + + import bottle + + self.host = host + self.port = port + self.server = server + self.bottle = bottle + + def wraps(self, func, arguments, method="GET"): + def new_func(): + outputs = {"code": 0, "desc": "succeeded", "data": {}} + kwargs = {} + for key, value in arguments.items(): + if method == "GET": + result = 
self.bottle.request.GET.getunicode(key) + else: + result = self.bottle.request.POST.getunicode(key) + if result is None: + if value[1]: + outputs["code"] = 1 + outputs["desc"] = 'lack of "%s" argument' % key + return json.dumps(outputs, ensure_ascii=False) + else: + if value[0] is not None: + result = value[0](result) + kwargs[key] = result + try: + outputs["data"] = func(**kwargs) + except Exception as e: + outputs["code"] = 2 + outputs["desc"] = str(e) + return json.dumps(outputs, ensure_ascii=False) + + return new_func + + def route(self, path, func, arguments, method="GET"): + func = self.wraps(func, arguments, method) + self.bottle.route(path, method=method)(func) + + def start(self): + self.bottle.run(host=self.host, port=self.port, server=self.server) + + +def get_pool_emb( + hidden_state=None, + pooler=None, + attention_mask=None, + pool_strategy="cls", + custom_layer=None, +): + if pool_strategy == "pooler": + return pooler + elif pool_strategy == "cls": + if isinstance(hidden_state, (list, tuple)): + hidden_state = hidden_state[-1] + assert isinstance( + hidden_state, torch.Tensor + ), f"{pool_strategy} strategy request tensor hidden_state" + return hidden_state[:, 0] + elif pool_strategy in {"last-avg", "mean"}: + if isinstance(hidden_state, (list, tuple)): + hidden_state = hidden_state[-1] + assert isinstance( + hidden_state, torch.Tensor + ), f"{pool_strategy} pooling strategy request tensor hidden_state" + hid = torch.sum(hidden_state * attention_mask[:, :, None], dim=1) + attention_mask = torch.sum(attention_mask, dim=1)[:, None] + return hid / attention_mask + elif pool_strategy in {"last-max", "max"}: + if isinstance(hidden_state, (list, tuple)): + hidden_state = hidden_state[-1] + assert isinstance( + hidden_state, torch.Tensor + ), f"{pool_strategy} pooling strategy request tensor hidden_state" + hid = hidden_state * attention_mask[:, :, None] + return torch.max(hid, dim=1) + elif pool_strategy == "first-last-avg": + assert isinstance( + hidden_state, list + ), f"{pool_strategy} pooling strategy request list hidden_state" + hid = torch.sum(hidden_state[1] * attention_mask[:, :, None], dim=1) + hid += torch.sum(hidden_state[-1] * attention_mask[:, :, None], dim=1) + attention_mask = torch.sum(attention_mask, dim=1)[:, None] + return hid / (2 * attention_mask) + elif pool_strategy == "custom": + assert isinstance( + hidden_state, list + ), f"{pool_strategy} pooling strategy request list hidden_state" + assert isinstance( + custom_layer, (int, list, tuple) + ), f"{pool_strategy} pooling strategy request int/list/tuple custom_layer" + custom_layer = [custom_layer] if isinstance(custom_layer, int) else custom_layer + hid = 0 + for i, layer in enumerate(custom_layer, start=1): + hid += torch.sum(hidden_state[layer] * attention_mask[:, :, None], dim=1) + attention_mask = torch.sum(attention_mask, dim=1)[:, None] + return hid / (i * attention_mask) + else: + raise ValueError("pool_strategy illegal") + + +def seed_everything(seed=None): + max_seed_value = np.iinfo(np.uint32).max + min_seed_value = np.iinfo(np.uint32).min + + if (seed is None) or not (min_seed_value <= seed <= max_seed_value): + random.randint(np.iinfo(np.uint32).min, np.iinfo(np.uint32).max) + os.environ["PYTHONHASHSEED"] = str(seed) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + return seed diff --git a/models/nlp/plm/bert_base_ner/igie/Int8QAT/run_qat.py b/models/nlp/plm/bert_base_ner/igie/Int8QAT/run_qat.py index 
707179359c34ab48ac39f1f1377ee20d08d4d91a..483aac50ace3f3e08a33327a9f8a838528aefa99 100644 --- a/models/nlp/plm/bert_base_ner/igie/Int8QAT/run_qat.py +++ b/models/nlp/plm/bert_base_ner/igie/Int8QAT/run_qat.py @@ -1,283 +1,283 @@ -# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -import json -import os -import sys - -import numpy as np -import torch -import torch.nn as nn -import torch.optim as optim -from torch.utils.data import DataLoader -from tqdm import tqdm -import argparse - -from bert4torch.layers import CRF -from bert4torch.models import BaseModel, build_transformer_model -from bert4torch.snippets import Callback, ListDataset, seed_everything, sequence_padding -from bert4torch.tokenizers import Tokenizer -from ls_hf_transformer_layer import inject_ls_layer - -maxlen = 256 -batch_size = 16 -categories = ["O", "B-LOC", "I-LOC", "B-PER", "I-PER", "B-ORG", "I-ORG"] -categories_id2label = {i: k for i, k in enumerate(categories)} -categories_label2id = {k: i for i, k in enumerate(categories)} - -def parse_args(): - parser = argparse.ArgumentParser() - - parser.add_argument("--model_dir", - type=str, - required=True, - help="pytorch weights dir.") - - parser.add_argument("--datasets_dir", - type=str, - required=True, - help="datasets dir.") - - args = parser.parse_args() - - return args - -class quan_model_config: - module_type = 2 - quant_mode = "qat" - enable_quant = True - - -class quan_train_config: - fp16 = False - local_rank = -1 - - -class quant_bert_config: - pass - -class MyDataset(ListDataset): - @staticmethod - def load_data(filename): - D = [] - with open(filename, encoding="utf-8") as f: - f = f.read() - for l in f.split("\n\n"): - if not l: - continue - d = [""] - for i, c in enumerate(l.split("\n")): - char, flag = c.split(" ") - d[0] += char - if flag[0] == "B": - d.append([i, i, flag[2:]]) - elif flag[0] == "I": - d[-1][1] = i - D.append(d) - return D - -class Model(BaseModel): - def __init__(self): - super().__init__() - self.bert = build_transformer_model( - config_path=config_path, - checkpoint_path=checkpoint_path, - segment_vocab_size=0, - ) - self.fc = nn.Linear(768, len(categories)) - self.crf = CRF(len(categories)) - - def forward(self, token_ids): - sequence_output = self.bert([token_ids]) # [btz, seq_len, hdsz] - emission_score = self.fc(sequence_output) # [btz, seq_len, tag_size] - attention_mask = token_ids.gt(0).long() - return emission_score, attention_mask - - def predict(self, token_ids): - self.eval() - with torch.no_grad(): - emission_score, attention_mask = self.forward(token_ids) - best_path = self.crf.decode( - emission_score, attention_mask - ) # [btz, seq_len] - return best_path - -class Loss(nn.Module): - def forward(self, outputs, labels): - return model.crf(*outputs, labels) - -def collate_fn(batch): - batch_token_ids, batch_labels = [], [] - for d in batch: - tokens = tokenizer.tokenize(d[0], maxlen=maxlen) - mapping = tokenizer.rematch(d[0], tokens) - 
start_mapping = {j[0]: i for i, j in enumerate(mapping) if j} - end_mapping = {j[-1]: i for i, j in enumerate(mapping) if j} - token_ids = tokenizer.tokens_to_ids(tokens) - labels = np.zeros(len(token_ids)) - for start, end, label in d[1:]: - if start in start_mapping and end in end_mapping: - start = start_mapping[start] - end = end_mapping[end] - labels[start] = categories_label2id["B-" + label] - labels[start + 1 : end + 1] = categories_label2id["I-" + label] - batch_token_ids.append(token_ids) - batch_labels.append(labels) - batch_token_ids = torch.tensor( - sequence_padding(batch_token_ids), dtype=torch.long, device=device - ) - batch_labels = torch.tensor( - sequence_padding(batch_labels), dtype=torch.long, device=device - ) - return batch_token_ids, batch_labels - -def trans_entity2tuple(scores): - batch_entity_ids = set() - for i, one_samp in enumerate(scores): - entity_ids = [] - for j, item in enumerate(one_samp): - flag_tag = categories_id2label[item.item()] - if flag_tag.startswith("B-"): # B - entity_ids.append([i, j, j, flag_tag[2:]]) - elif len(entity_ids) == 0: - continue - elif ( - (len(entity_ids[-1]) > 0) - and flag_tag.startswith("I-") - and (flag_tag[2:] == entity_ids[-1][-1]) - ): # I - entity_ids[-1][-2] = j - elif len(entity_ids[-1]) > 0: - entity_ids.append([]) - - for i in entity_ids: - if i: - batch_entity_ids.add(tuple(i)) - return batch_entity_ids - -def evaluate(data): - X, Y, Z = 1e-10, 1e-10, 1e-10 - X2, Y2, Z2 = 1e-10, 1e-10, 1e-10 - for token_ids, label in tqdm(data): - scores = model.predict(token_ids) - attention_mask = label.gt(0) - - # token粒度 - X += (scores.eq(label) * attention_mask).sum().item() - Y += scores.gt(0).sum().item() - Z += label.gt(0).sum().item() - - # entity粒度 - entity_pred = trans_entity2tuple(scores) - entity_true = trans_entity2tuple(label) - X2 += len(entity_pred.intersection(entity_true)) - Y2 += len(entity_pred) - Z2 += len(entity_true) - f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z - f2, precision2, recall2 = 2 * X2 / (Y2 + Z2), X2 / Y2, X2 / Z2 - return f1, precision, recall, f2, precision2, recall2 - -class Evaluator(Callback): - def __init__(self): - self.best_val_f1 = 0.0 - - def on_epoch_end(self, steps, epoch, logs=None): - f1, precision, recall, f2, precision2, recall2 = evaluate(valid_dataloader) - if f2 > self.best_val_f1: - self.best_val_f1 = f2 - print(f"[val-token level] f1: {f1:.5f}, p: {precision:.5f} r: {recall:.5f}") - print( - f"[val-entity level] f1: {f2:.5f}, p: {precision2:.5f} r: {recall2:.5f} best_f1: {self.best_val_f1:.5f}\n" - ) - - -if __name__ == "__main__": - - args = parse_args() - - model_dir = args.model_dir - data_dir = args.datasets_dir - - config_path = os.path.join(model_dir, "config.json") - checkpoint_path = os.path.join(model_dir, "pytorch_model.bin") - dict_path = os.path.join(model_dir, "vocab.txt") - - if not os.path.isfile(checkpoint_path): - print("cant found checkpoint_path: {}".format(checkpoint_path)) - assert os.path.isfile(checkpoint_path) - - - if not os.path.isfile(config_path): - print("cant found config path: {}".format(config_path)) - assert os.path.isfile(config_path) - - - if not os.path.isfile(dict_path): - print("cant found dict_path: {}".format(dict_path)) - assert os.path.isfile(dict_path) - - device = "cuda" if torch.cuda.is_available() else "cpu" - - seed_everything(42) - - tokenizer = Tokenizer(dict_path, do_lower_case=True) - - train_data_file = os.path.join(data_dir, "example.train") - dev_data_file = os.path.join(data_dir, "example.dev") - - if not 
os.path.isfile(train_data_file): - print("cant found train data file: {}".format(train_data_file)) - assert os.path.isfile(train_data_file) - - if not os.path.isfile(dev_data_file): - print("cant found dev data file: {}".format(dev_data_file)) - assert os.path.isfile(dev_data_file) - - train_dataloader = DataLoader( - MyDataset(train_data_file), - batch_size=batch_size, - shuffle=True, - collate_fn=collate_fn, - ) - - valid_dataloader = DataLoader( - MyDataset(dev_data_file), batch_size=batch_size, collate_fn=collate_fn - ) - - quant_training_args = quan_train_config() - quant_model_args = quan_model_config() - quant_bert_args = quant_bert_config() - with open(config_path, "r") as f: - data = json.load(f) - quant_bert_args.__dict__.update(data) - - model = Model() - inject_ls_layer(model, quant_training_args, quant_model_args, quant_bert_args) - - model.to(device) - - model.compile(loss=Loss(), optimizer=optim.Adam(model.parameters(), lr=2e-5)) - - evaluator = Evaluator() - - model.fit(train_dataloader, epochs=3, steps_per_epoch=None, callbacks=[evaluator]) - - quant_dir = "quant_base/" - - if not os.path.isdir(quant_dir): - os.makedirs(quant_dir) - save_file = os.path.join(quant_dir, "pytorch_model.bin") - - model.save_weights(save_file) +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
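+
+# This script performs INT8 quantization-aware training (QAT) for a BERT +
+# CRF named-entity-recognition model: it loads a pretrained checkpoint from
+# --model_dir, reads BIO-tagged data (example.train / example.dev, with LOC /
+# PER / ORG entities) from --datasets_dir, injects quantized layers through
+# inject_ls_layer, trains for 3 epochs while reporting token- and
+# entity-level F1, and saves the resulting weights to
+# quant_base/pytorch_model.bin.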
+ +import json +import os +import sys + +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader +from tqdm import tqdm +import argparse + +from bert4torch.layers import CRF +from bert4torch.models import BaseModel, build_transformer_model +from bert4torch.snippets import Callback, ListDataset, seed_everything, sequence_padding +from bert4torch.tokenizers import Tokenizer +from ls_hf_transformer_layer import inject_ls_layer + +maxlen = 256 +batch_size = 16 +categories = ["O", "B-LOC", "I-LOC", "B-PER", "I-PER", "B-ORG", "I-ORG"] +categories_id2label = {i: k for i, k in enumerate(categories)} +categories_label2id = {k: i for i, k in enumerate(categories)} + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--model_dir", + type=str, + required=True, + help="pytorch weights dir.") + + parser.add_argument("--datasets_dir", + type=str, + required=True, + help="datasets dir.") + + args = parser.parse_args() + + return args + +class quan_model_config: + module_type = 2 + quant_mode = "qat" + enable_quant = True + + +class quan_train_config: + fp16 = False + local_rank = -1 + + +class quant_bert_config: + pass + +class MyDataset(ListDataset): + @staticmethod + def load_data(filename): + D = [] + with open(filename, encoding="utf-8") as f: + f = f.read() + for l in f.split("\n\n"): + if not l: + continue + d = [""] + for i, c in enumerate(l.split("\n")): + char, flag = c.split(" ") + d[0] += char + if flag[0] == "B": + d.append([i, i, flag[2:]]) + elif flag[0] == "I": + d[-1][1] = i + D.append(d) + return D + +class Model(BaseModel): + def __init__(self): + super().__init__() + self.bert = build_transformer_model( + config_path=config_path, + checkpoint_path=checkpoint_path, + segment_vocab_size=0, + ) + self.fc = nn.Linear(768, len(categories)) + self.crf = CRF(len(categories)) + + def forward(self, token_ids): + sequence_output = self.bert([token_ids]) # [btz, seq_len, hdsz] + emission_score = self.fc(sequence_output) # [btz, seq_len, tag_size] + attention_mask = token_ids.gt(0).long() + return emission_score, attention_mask + + def predict(self, token_ids): + self.eval() + with torch.no_grad(): + emission_score, attention_mask = self.forward(token_ids) + best_path = self.crf.decode( + emission_score, attention_mask + ) # [btz, seq_len] + return best_path + +class Loss(nn.Module): + def forward(self, outputs, labels): + return model.crf(*outputs, labels) + +def collate_fn(batch): + batch_token_ids, batch_labels = [], [] + for d in batch: + tokens = tokenizer.tokenize(d[0], maxlen=maxlen) + mapping = tokenizer.rematch(d[0], tokens) + start_mapping = {j[0]: i for i, j in enumerate(mapping) if j} + end_mapping = {j[-1]: i for i, j in enumerate(mapping) if j} + token_ids = tokenizer.tokens_to_ids(tokens) + labels = np.zeros(len(token_ids)) + for start, end, label in d[1:]: + if start in start_mapping and end in end_mapping: + start = start_mapping[start] + end = end_mapping[end] + labels[start] = categories_label2id["B-" + label] + labels[start + 1 : end + 1] = categories_label2id["I-" + label] + batch_token_ids.append(token_ids) + batch_labels.append(labels) + batch_token_ids = torch.tensor( + sequence_padding(batch_token_ids), dtype=torch.long, device=device + ) + batch_labels = torch.tensor( + sequence_padding(batch_labels), dtype=torch.long, device=device + ) + return batch_token_ids, batch_labels + +def trans_entity2tuple(scores): + batch_entity_ids = set() + for i, one_samp in enumerate(scores): 
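+        # Group the tag-id sequence of sample i into (i, start, end, type)
+        # spans: a "B-" tag opens a new span, an "I-" tag of the same type
+        # extends its end index, and any other tag closes the current span;
+        # completed spans are collected into a set for comparison below.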
+ entity_ids = [] + for j, item in enumerate(one_samp): + flag_tag = categories_id2label[item.item()] + if flag_tag.startswith("B-"): # B + entity_ids.append([i, j, j, flag_tag[2:]]) + elif len(entity_ids) == 0: + continue + elif ( + (len(entity_ids[-1]) > 0) + and flag_tag.startswith("I-") + and (flag_tag[2:] == entity_ids[-1][-1]) + ): # I + entity_ids[-1][-2] = j + elif len(entity_ids[-1]) > 0: + entity_ids.append([]) + + for i in entity_ids: + if i: + batch_entity_ids.add(tuple(i)) + return batch_entity_ids + +def evaluate(data): + X, Y, Z = 1e-10, 1e-10, 1e-10 + X2, Y2, Z2 = 1e-10, 1e-10, 1e-10 + for token_ids, label in tqdm(data): + scores = model.predict(token_ids) + attention_mask = label.gt(0) + + # token粒度 + X += (scores.eq(label) * attention_mask).sum().item() + Y += scores.gt(0).sum().item() + Z += label.gt(0).sum().item() + + # entity粒度 + entity_pred = trans_entity2tuple(scores) + entity_true = trans_entity2tuple(label) + X2 += len(entity_pred.intersection(entity_true)) + Y2 += len(entity_pred) + Z2 += len(entity_true) + f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z + f2, precision2, recall2 = 2 * X2 / (Y2 + Z2), X2 / Y2, X2 / Z2 + return f1, precision, recall, f2, precision2, recall2 + +class Evaluator(Callback): + def __init__(self): + self.best_val_f1 = 0.0 + + def on_epoch_end(self, steps, epoch, logs=None): + f1, precision, recall, f2, precision2, recall2 = evaluate(valid_dataloader) + if f2 > self.best_val_f1: + self.best_val_f1 = f2 + print(f"[val-token level] f1: {f1:.5f}, p: {precision:.5f} r: {recall:.5f}") + print( + f"[val-entity level] f1: {f2:.5f}, p: {precision2:.5f} r: {recall2:.5f} best_f1: {self.best_val_f1:.5f}\n" + ) + + +if __name__ == "__main__": + + args = parse_args() + + model_dir = args.model_dir + data_dir = args.datasets_dir + + config_path = os.path.join(model_dir, "config.json") + checkpoint_path = os.path.join(model_dir, "pytorch_model.bin") + dict_path = os.path.join(model_dir, "vocab.txt") + + if not os.path.isfile(checkpoint_path): + print("cant found checkpoint_path: {}".format(checkpoint_path)) + assert os.path.isfile(checkpoint_path) + + + if not os.path.isfile(config_path): + print("cant found config path: {}".format(config_path)) + assert os.path.isfile(config_path) + + + if not os.path.isfile(dict_path): + print("cant found dict_path: {}".format(dict_path)) + assert os.path.isfile(dict_path) + + device = "cuda" if torch.cuda.is_available() else "cpu" + + seed_everything(42) + + tokenizer = Tokenizer(dict_path, do_lower_case=True) + + train_data_file = os.path.join(data_dir, "example.train") + dev_data_file = os.path.join(data_dir, "example.dev") + + if not os.path.isfile(train_data_file): + print("cant found train data file: {}".format(train_data_file)) + assert os.path.isfile(train_data_file) + + if not os.path.isfile(dev_data_file): + print("cant found dev data file: {}".format(dev_data_file)) + assert os.path.isfile(dev_data_file) + + train_dataloader = DataLoader( + MyDataset(train_data_file), + batch_size=batch_size, + shuffle=True, + collate_fn=collate_fn, + ) + + valid_dataloader = DataLoader( + MyDataset(dev_data_file), batch_size=batch_size, collate_fn=collate_fn + ) + + quant_training_args = quan_train_config() + quant_model_args = quan_model_config() + quant_bert_args = quant_bert_config() + with open(config_path, "r") as f: + data = json.load(f) + quant_bert_args.__dict__.update(data) + + model = Model() + inject_ls_layer(model, quant_training_args, quant_model_args, quant_bert_args) + + model.to(device) + + 
model.compile(loss=Loss(), optimizer=optim.Adam(model.parameters(), lr=2e-5)) + + evaluator = Evaluator() + + model.fit(train_dataloader, epochs=3, steps_per_epoch=None, callbacks=[evaluator]) + + quant_dir = "quant_base/" + + if not os.path.isdir(quant_dir): + os.makedirs(quant_dir) + save_file = os.path.join(quant_dir, "pytorch_model.bin") + + model.save_weights(save_file) diff --git a/models/nlp/plm/bert_base_ner/igie/requirements.txt b/models/nlp/plm/bert_base_ner/igie/requirements.txt index f623aa95796ba66821f14ef560fcf42dc98ef2dc..eb25b6a06a709f1fd326c97ac448d23af2ce7d52 100644 --- a/models/nlp/plm/bert_base_ner/igie/requirements.txt +++ b/models/nlp/plm/bert_base_ner/igie/requirements.txt @@ -1,5 +1,5 @@ -onnx -tqdm -transformers -bert4torch -numpy==1.23.5 +onnx +tqdm +transformers +bert4torch +numpy==1.23.5 diff --git a/models/nlp/plm/bert_base_squad/igie/requirements.txt b/models/nlp/plm/bert_base_squad/igie/requirements.txt index a056621131f34d34d84e877ca6f587a5cb490bac..4b4dc40ab9d909063ccc6e3dfe215a6a9f479b73 100644 --- a/models/nlp/plm/bert_base_squad/igie/requirements.txt +++ b/models/nlp/plm/bert_base_squad/igie/requirements.txt @@ -1,3 +1,3 @@ -onnx -tqdm -transformers==4.37.1 +onnx +tqdm +transformers==4.37.1 diff --git a/models/nlp/plm/bert_large_squad/igie/requirements.txt b/models/nlp/plm/bert_large_squad/igie/requirements.txt index e2b0e79ceecdfe8b4ca7bbf1ca67753069158d18..7543264ea4b0f7cb62b1a8567944d4abab37a2c9 100644 --- a/models/nlp/plm/bert_large_squad/igie/requirements.txt +++ b/models/nlp/plm/bert_large_squad/igie/requirements.txt @@ -1,4 +1,4 @@ -onnx -tqdm -numpy==1.23.5 -transformers==4.37.1 +onnx +tqdm +numpy==1.23.5 +transformers==4.37.1 diff --git a/toolbox/ByteMLPerf/byte_infer_perf/general_perf/backends/ILUVATAR/utils/onnx_rewrite_batch_size.py b/toolbox/ByteMLPerf/byte_infer_perf/general_perf/backends/ILUVATAR/utils/onnx_rewrite_batch_size.py index 5332febfb8f2ce169bafaf9c74683814f07ac8b0..4eab3b87644f42fa488c06248f8ad64c5cb498f9 100644 --- a/toolbox/ByteMLPerf/byte_infer_perf/general_perf/backends/ILUVATAR/utils/onnx_rewrite_batch_size.py +++ b/toolbox/ByteMLPerf/byte_infer_perf/general_perf/backends/ILUVATAR/utils/onnx_rewrite_batch_size.py @@ -1,113 +1,113 @@ -""" -rewrite src onnx model and infer shape if possible, current sypport - -1. rewrite batch_size, e.g 1x3x640x640 -> 32x3x640x640 - -Attention: -1. 
all inputs/outputs batchszie dim will be modified together, which means some NLP/Audio senquence models will introduce problems - - -""" -import onnx -from onnx import OperatorSetIdProto -import onnx.numpy_helper - -import onnxoptimizer -from onnxsim import simplify - -from .onnx_util import get_batch_size, rewrite_tensor_batch_size - -def rewrite_batch_size(model, - batch_size, - modify_reshape_dim=True, - save_model_path=None): - - ## rewrite input and output - if isinstance(model, str): - model = onnx.load(model) - - - ## there is a issue that when the onnx model comes from tf, - ## some shape info is stored as constant node's output instead of initializer - passes = [ - "extract_constant_to_initializer", "eliminate_unused_initializer" - ] - model = onnxoptimizer.optimize(model, passes) - - - - # to support qlinear op if the opset_import is not supported - # if we have some ohter domains need to import, add them here - ms_opset = OperatorSetIdProto() - ms_opset.domain = "com.microsoft" - ms_opset.version = 1 - - ori_opset_import = model.opset_import - - if ms_opset not in ori_opset_import: - ori_opset_import.append(ms_opset) - - model, check = simplify(model) - assert check, "Simplified ONNX model could not be validated" - - - graph = model.graph - initializer = graph.initializer - inputs = graph.input - outputs = graph.output - nodes = graph.node - - ori_batch_size = get_batch_size(model) - - ## in case that some onnx model inputs contain initializers' shape info, we will remove them to avoid rewriting input failure - - initializer_names = set([i.name for i in initializer]) - import copy - tmp_inputs = copy.deepcopy(inputs) - for i in tmp_inputs: - if i.name in initializer_names: - inputs.remove(i) - - for i in inputs: - rewrite_tensor_batch_size(i, batch_size) - - for i in outputs: - rewrite_tensor_batch_size(i, batch_size) - - ## we may need to modify reshape initializer if we modify input batchsize - ## this code only works when the target shape is fixed, and occurs as a input initializer in the node - ## so this may introduce some other problems when the purpose of reshape operations are totally different - - if modify_reshape_dim: - reshape_input = [] - for idx, i in enumerate(nodes): - if i.op_type == "Reshape": - reshape_input.extend(i.input) - if i.op_type == "Resize" and len(i.input) == 4: - reshape_input.append(i.input[3]) - for idx, i in enumerate(initializer): - if i.name in reshape_input: - shape = onnx.numpy_helper.to_array(i).copy() - if shape.dtype == "int64": - shape[0] = batch_size - initializer[idx].CopyFrom( - onnx.numpy_helper.from_array(shape, i.name)) - - for i in graph.value_info: - if i.type.tensor_type.shape.dim: - if i.type.tensor_type.shape.dim[0].dim_value == ori_batch_size: - i.type.tensor_type.shape.dim[0].dim_value = batch_size - - model, check = simplify(model) - assert check, "Simplified ONNX model could not be validated" - - model = onnx.shape_inference.infer_shapes(model, - check_type=True, - strict_mode=True, - data_prop=True) - onnx.checker.check_model(model) - - if save_model_path: - onnx.save(model, save_model_path) - return model - +""" +rewrite src onnx model and infer shape if possible, current sypport + +1. rewrite batch_size, e.g 1x3x640x640 -> 32x3x640x640 + +Attention: +1. 
all inputs/outputs batch size dims will be modified together, which means some NLP/Audio sequence models may run into problems
+
+
+"""
+import onnx
+from onnx import OperatorSetIdProto
+import onnx.numpy_helper
+
+import onnxoptimizer
+from onnxsim import simplify
+
+from .onnx_util import get_batch_size, rewrite_tensor_batch_size
+
+def rewrite_batch_size(model,
+                       batch_size,
+                       modify_reshape_dim=True,
+                       save_model_path=None):
+
+    ## rewrite input and output
+    if isinstance(model, str):
+        model = onnx.load(model)
+
+
+    ## there is an issue that when the onnx model comes from tf,
+    ## some shape info is stored as a constant node's output instead of an initializer
+    passes = [
+        "extract_constant_to_initializer", "eliminate_unused_initializer"
+    ]
+    model = onnxoptimizer.optimize(model, passes)
+
+
+
+    # to support qlinear ops if the com.microsoft opset_import is not present
+    # if other domains need to be imported, add them here
+    ms_opset = OperatorSetIdProto()
+    ms_opset.domain = "com.microsoft"
+    ms_opset.version = 1
+
+    ori_opset_import = model.opset_import
+
+    if ms_opset not in ori_opset_import:
+        ori_opset_import.append(ms_opset)
+
+    model, check = simplify(model)
+    assert check, "Simplified ONNX model could not be validated"
+
+
+    graph = model.graph
+    initializer = graph.initializer
+    inputs = graph.input
+    outputs = graph.output
+    nodes = graph.node
+
+    ori_batch_size = get_batch_size(model)
+
+    ## in case some onnx model inputs contain initializers' shape info, we remove them to avoid failures when rewriting inputs
+
+    initializer_names = set([i.name for i in initializer])
+    import copy
+    tmp_inputs = copy.deepcopy(inputs)
+    for i in tmp_inputs:
+        if i.name in initializer_names:
+            inputs.remove(i)
+
+    for i in inputs:
+        rewrite_tensor_batch_size(i, batch_size)
+
+    for i in outputs:
+        rewrite_tensor_batch_size(i, batch_size)
+
+    ## we may need to modify the reshape initializer if we modify the input batch size
+    ## this code only works when the target shape is fixed and occurs as an input initializer of the node
+    ## so it may introduce problems when the purpose of the reshape operation is totally different
+
+    if modify_reshape_dim:
+        reshape_input = []
+        for idx, i in enumerate(nodes):
+            if i.op_type == "Reshape":
+                reshape_input.extend(i.input)
+            if i.op_type == "Resize" and len(i.input) == 4:
+                reshape_input.append(i.input[3])
+        for idx, i in enumerate(initializer):
+            if i.name in reshape_input:
+                shape = onnx.numpy_helper.to_array(i).copy()
+                if shape.dtype == "int64":
+                    shape[0] = batch_size
+                    initializer[idx].CopyFrom(
+                        onnx.numpy_helper.from_array(shape, i.name))
+
+    for i in graph.value_info:
+        if i.type.tensor_type.shape.dim:
+            if i.type.tensor_type.shape.dim[0].dim_value == ori_batch_size:
+                i.type.tensor_type.shape.dim[0].dim_value = batch_size
+
+    model, check = simplify(model)
+    assert check, "Simplified ONNX model could not be validated"
+
+    model = onnx.shape_inference.infer_shapes(model,
+                                              check_type=True,
+                                              strict_mode=True,
+                                              data_prop=True)
+    onnx.checker.check_model(model)
+
+    if save_model_path:
+        onnx.save(model, save_model_path)
+    return model
+
diff --git a/toolbox/ByteMLPerf/byte_infer_perf/general_perf/backends/ILUVATAR/utils/onnx_util.py b/toolbox/ByteMLPerf/byte_infer_perf/general_perf/backends/ILUVATAR/utils/onnx_util.py
index 96823647216acb23b8f1a6be39d00aacf53107ec..afe2aabf055a63fe7676d796aa5c891d51684169 100644
--- a/toolbox/ByteMLPerf/byte_infer_perf/general_perf/backends/ILUVATAR/utils/onnx_util.py
+++ 
b/toolbox/ByteMLPerf/byte_infer_perf/general_perf/backends/ILUVATAR/utils/onnx_util.py @@ -1,130 +1,130 @@ -import onnx -from collections import defaultdict - -import onnx -import os - -## FYI -ONNX_DTYPE = { - 0: onnx.TensorProto.FLOAT, - 1: onnx.TensorProto.FLOAT, - 2: onnx.TensorProto.UINT8, - 3: onnx.TensorProto.INT8, - 4: onnx.TensorProto.UINT16, - 5: onnx.TensorProto.INT16, - 6: onnx.TensorProto.INT32, - 7: onnx.TensorProto.INT64, - 8: onnx.TensorProto.STRING, - 9: onnx.TensorProto.BOOL, -} - - -def rewrite_tensor_dim(tensor, dim_value_dict): - if isinstance(dim_value_dict, list): - dim_value_dict = {idx: i for idx, i in enumerate(dim_value_dict)} - all_dim = tensor.type.tensor_type.shape.dim - for idx, value in dim_value_dict.items(): - if isinstance(value, str): - all_dim[idx].dim_param = "batch" - else: - all_dim[idx].dim_value = value - - -def rewrite_tensor_batch_size(tensor, batch_size): - - dim_value_dict = {0: batch_size} - rewrite_tensor_dim(tensor, dim_value_dict) - - -def get_tensor_dim(tensor): - dims = [] - all_dim = tensor.type.tensor_type.shape.dim - rank = len(all_dim) - for i in range(rank): - if all_dim[i].dim_value: - dims.append(all_dim[i].dim_value) - else: - dims.append(all_dim[i].dim_param) - return dims - - -def get_tensor_name(tensor): - return tensor.name - - -def nchw_dim_to_nhwc_dim(dim_list): - assert len(dim_list) == 4 - new_dim = [dim_list[0], dim_list[2], dim_list[3], dim_list[1]] - return new_dim - - -def get_input_number(model): - if isinstance(model, str): - model = onnx.load(model) - inputs = model.graph.input - return len(inputs) - -def get_batch_size(model): - if isinstance(model, str): - model = onnx.load(model) - inputs = model.graph.input - return get_tensor_dim(inputs[0])[0] - - -def count_op_type(model): - if isinstance(model, str): - model = onnx.load(model) - - nodes = model.graph.node - - node2count = defaultdict(int) - for i in nodes: - node2count[i.op_type] += 1 - - return node2count - - -def contain_qlinear_opearator(onnx_model): - if isinstance(onnx_model, str): - onnx_model = onnx.load(onnx_model) - - nodes = onnx_model.graph.node - - for i in nodes: - op_type = i.op_type.lower() - if op_type.startswith("qlinear") or op_type.startswith("qgemm"): - return True - return False - - -def get_all_node_name(model, exclude_constant=False, pretty_print=False): - if isinstance(model, str): - model = onnx.load(model) - - nodes = model.graph.node - if exclude_constant: - all_node = [i.name for i in nodes if i.op_type != "Constant"] - else: - all_node = [i.name for i in nodes] - - all_node.sort() - if pretty_print: - res = [f'"{i}"' for i in all_node] - res = ",\n".join(res) - res = f'[\n{res}\n]' - print(res) - - return all_node - -def rewrite_int64_input_to_int32(model): - inputs = model.graph.input - - for i in inputs: - if i.type.tensor_type.elem_type == 7: - i.type.tensor_type.elem_type = 6 - - print(inputs) - import pdb;pdb.set_trace() - onnx.checker.check_model(model) - +import onnx +from collections import defaultdict + +import onnx +import os + +## FYI +ONNX_DTYPE = { + 0: onnx.TensorProto.FLOAT, + 1: onnx.TensorProto.FLOAT, + 2: onnx.TensorProto.UINT8, + 3: onnx.TensorProto.INT8, + 4: onnx.TensorProto.UINT16, + 5: onnx.TensorProto.INT16, + 6: onnx.TensorProto.INT32, + 7: onnx.TensorProto.INT64, + 8: onnx.TensorProto.STRING, + 9: onnx.TensorProto.BOOL, +} + + +def rewrite_tensor_dim(tensor, dim_value_dict): + if isinstance(dim_value_dict, list): + dim_value_dict = {idx: i for idx, i in enumerate(dim_value_dict)} + all_dim = 
tensor.type.tensor_type.shape.dim + for idx, value in dim_value_dict.items(): + if isinstance(value, str): + all_dim[idx].dim_param = "batch" + else: + all_dim[idx].dim_value = value + + +def rewrite_tensor_batch_size(tensor, batch_size): + + dim_value_dict = {0: batch_size} + rewrite_tensor_dim(tensor, dim_value_dict) + + +def get_tensor_dim(tensor): + dims = [] + all_dim = tensor.type.tensor_type.shape.dim + rank = len(all_dim) + for i in range(rank): + if all_dim[i].dim_value: + dims.append(all_dim[i].dim_value) + else: + dims.append(all_dim[i].dim_param) + return dims + + +def get_tensor_name(tensor): + return tensor.name + + +def nchw_dim_to_nhwc_dim(dim_list): + assert len(dim_list) == 4 + new_dim = [dim_list[0], dim_list[2], dim_list[3], dim_list[1]] + return new_dim + + +def get_input_number(model): + if isinstance(model, str): + model = onnx.load(model) + inputs = model.graph.input + return len(inputs) + +def get_batch_size(model): + if isinstance(model, str): + model = onnx.load(model) + inputs = model.graph.input + return get_tensor_dim(inputs[0])[0] + + +def count_op_type(model): + if isinstance(model, str): + model = onnx.load(model) + + nodes = model.graph.node + + node2count = defaultdict(int) + for i in nodes: + node2count[i.op_type] += 1 + + return node2count + + +def contain_qlinear_opearator(onnx_model): + if isinstance(onnx_model, str): + onnx_model = onnx.load(onnx_model) + + nodes = onnx_model.graph.node + + for i in nodes: + op_type = i.op_type.lower() + if op_type.startswith("qlinear") or op_type.startswith("qgemm"): + return True + return False + + +def get_all_node_name(model, exclude_constant=False, pretty_print=False): + if isinstance(model, str): + model = onnx.load(model) + + nodes = model.graph.node + if exclude_constant: + all_node = [i.name for i in nodes if i.op_type != "Constant"] + else: + all_node = [i.name for i in nodes] + + all_node.sort() + if pretty_print: + res = [f'"{i}"' for i in all_node] + res = ",\n".join(res) + res = f'[\n{res}\n]' + print(res) + + return all_node + +def rewrite_int64_input_to_int32(model): + inputs = model.graph.input + + for i in inputs: + if i.type.tensor_type.elem_type == 7: + i.type.tensor_type.elem_type = 6 + + print(inputs) + import pdb;pdb.set_trace() + onnx.checker.check_model(model) + return model \ No newline at end of file diff --git a/toolbox/ByteMLPerf/byte_infer_perf/general_perf/version.py b/toolbox/ByteMLPerf/byte_infer_perf/general_perf/version.py index 608f35d6f6b03ca23f46fbd6500fc32f694a858f..1f356cc57bfa00a3b251402604c54702fb414c96 100644 --- a/toolbox/ByteMLPerf/byte_infer_perf/general_perf/version.py +++ b/toolbox/ByteMLPerf/byte_infer_perf/general_perf/version.py @@ -1 +1 @@ -__version__ = '1.0.0' +__version__ = '1.0.0'
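
For reference, the sketch below shows how the batch-size rewriting helpers touched by this patch might be driven from a small script; it is illustrative only and not part of the patch. The import path, the file names model.onnx / model_bs32.onnx, and the target batch size of 32 are assumptions; rewrite_batch_size and get_batch_size are the functions defined in onnx_rewrite_batch_size.py and onnx_util.py above.

# Usage sketch (assumptions: the ILUVATAR utils package is importable from sys.path
# and model.onnx exists; adjust both to the real environment).
from general_perf.backends.ILUVATAR.utils.onnx_rewrite_batch_size import rewrite_batch_size
from general_perf.backends.ILUVATAR.utils.onnx_util import get_batch_size

src_path = "model.onnx"        # hypothetical input model
dst_path = "model_bs32.onnx"   # hypothetical output path

print("original batch size:", get_batch_size(src_path))

# Rewrite every input/output batch dimension to 32, simplify, re-infer shapes,
# and save the result; the rewritten ModelProto is also returned.
rewritten = rewrite_batch_size(src_path, 32, save_model_path=dst_path)
print("new batch size:", get_batch_size(rewritten))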