From e1c516b18850fa7a32d7415c1a1021daf4aa5746 Mon Sep 17 00:00:00 2001 From: fallenoak Date: Thu, 29 Dec 2022 22:18:14 -0600 Subject: [PATCH] feat(thread): add OsTls functions --- common/CMakeLists.txt | 1 + common/Thread.hpp | 6 +++ common/thread/Tls.cpp | 113 ++++++++++++++++++++++++++++++++++++++++++ common/thread/Tls.hpp | 12 +++++ test/Thread.cpp | 26 ++++++++++ 5 files changed, 158 insertions(+) create mode 100644 common/Thread.hpp create mode 100644 common/thread/Tls.cpp create mode 100644 common/thread/Tls.hpp create mode 100644 test/Thread.cpp diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index da3b683..130e30f 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -3,6 +3,7 @@ file(GLOB COMMON_SOURCES "objectalloc/*.cpp" "ref/*.cpp" "string/*.cpp" + "thread/*.cpp" ) add_library(common STATIC diff --git a/common/Thread.hpp b/common/Thread.hpp new file mode 100644 index 0000000..3b47b25 --- /dev/null +++ b/common/Thread.hpp @@ -0,0 +1,6 @@ +#ifndef COMMON_THREAD_HPP +#define COMMON_THREAD_HPP + +#include "common/thread/Tls.hpp" + +#endif diff --git a/common/thread/Tls.cpp b/common/thread/Tls.cpp new file mode 100644 index 0000000..fe96d4f --- /dev/null +++ b/common/thread/Tls.cpp @@ -0,0 +1,113 @@ +#include "common/thread/Tls.hpp" +#include + +#if defined(WHOA_SYSTEM_WIN) +#include +#endif + +#if defined(WHOA_SYSTEM_MAC) || defined(WHOA_SYSTEM_LINUX) +#include +#include +#include +#endif + +#if defined(WHOA_SYSTEM_MAC) || defined(WHOA_SYSTEM_LINUX) +typedef void* TLSData; + +struct TLSSlot : public TSLinkedNode { + TSGrowableArray storage; +}; + +namespace OsTls { + int8_t s_initialized; + int32_t s_nextIndex; + TSList> s_tlsCleanupList; + pthread_key_t s_tlsKey; + SCritSect s_tlsLock; +} +#endif + +int32_t OsTlsAlloc() { +#if defined(WHOA_SYSTEM_WIN) + return TlsAlloc(); +#endif + +#if defined(WHOA_SYSTEM_MAC) || defined(WHOA_SYSTEM_LINUX) + OsTls::s_tlsLock.Enter(); + + if (!OsTls::s_initialized) { + OsTls::s_initialized = 1; + OsTls::s_nextIndex = 0; + pthread_key_create(&OsTls::s_tlsKey, nullptr); + } + + int32_t index = OsTls::s_nextIndex++; + + OsTls:: s_tlsLock.Leave(); + + return index; +#endif +} + +void* OsTlsGetValue(uint32_t tlsIndex) { +#if defined(WHOA_SYSTEM_WIN) + return TlsGetValue(tlsIndex); +#endif + +#if defined(WHOA_SYSTEM_MAC) || defined(WHOA_SYSTEM_LINUX) + if (!OsTls::s_initialized) { + return nullptr; + } + + auto slot = static_cast(pthread_getspecific(OsTls::s_tlsKey)); + + if (!slot) { + OsTls::s_tlsLock.Enter(); + + slot = OsTls::s_tlsCleanupList.NewNode(1, 1, 0x8); + + OsTls::s_tlsLock.Leave(); + + pthread_setspecific(OsTls::s_tlsKey, slot); + } + + if (slot->storage.Count() > tlsIndex) { + return slot->storage[tlsIndex]; + } else { + return nullptr; + } +#endif +} + +int32_t OsTlsSetValue(uint32_t tlsIndex, void* tlsValue) { +#if defined(WHOA_SYSTEM_WIN) + return TlsSetValue(tlsIndex, tlsValue); +#endif + +#if defined(WHOA_SYSTEM_MAC) || defined(WHOA_SYSTEM_LINUX) + if (!OsTls::s_initialized) { + return 0; + } + + auto slot = static_cast(pthread_getspecific(OsTls::s_tlsKey)); + + if (!slot) { + OsTls::s_tlsLock.Enter(); + + slot = OsTls::s_tlsCleanupList.NewNode(1, 1, 0x8); + + OsTls::s_tlsLock.Leave(); + + pthread_setspecific(OsTls::s_tlsKey, slot); + } + + if (slot->storage.Count() > tlsIndex) { + slot->storage[tlsIndex] = tlsValue; + } else { + slot->storage.GrowToFit(tlsIndex, 0); + slot->storage[tlsIndex] = tlsValue; + } + + return 1; +#endif +} diff --git a/common/thread/Tls.hpp b/common/thread/Tls.hpp new file mode 100644 index 0000000..b901fc1 --- /dev/null +++ b/common/thread/Tls.hpp @@ -0,0 +1,12 @@ +#ifndef COMMON_THREAD_TLS_HPP +#define COMMON_THREAD_TLS_HPP + +#include + +int32_t OsTlsAlloc(); + +void* OsTlsGetValue(uint32_t tlsIndex); + +int32_t OsTlsSetValue(uint32_t tlsIndex, void* tlsValue); + +#endif diff --git a/test/Thread.cpp b/test/Thread.cpp new file mode 100644 index 0000000..51162d1 --- /dev/null +++ b/test/Thread.cpp @@ -0,0 +1,26 @@ +#include "common/Thread.hpp" +#include "test/Test.hpp" + +TEST_CASE("OsTlsAlloc", "[thread]") { + SECTION("allocates tls index") { + int32_t tlsIndex = OsTlsAlloc(); + REQUIRE(tlsIndex >= 0); + } +} + +TEST_CASE("OsTlsSetValue", "[thread]") { + SECTION("sets value in tls index") { + int32_t tlsIndex = OsTlsAlloc(); + uint32_t tlsValue = 123; + REQUIRE(OsTlsSetValue(tlsIndex, &tlsValue)); + } +} + +TEST_CASE("OsTlsGetValue", "[thread]") { + SECTION("gets value in tls index") { + int32_t tlsIndex = OsTlsAlloc(); + uint32_t tlsValue = 456; + OsTlsSetValue(tlsIndex, &tlsValue); + REQUIRE(OsTlsGetValue(tlsIndex) == &tlsValue); + } +}