// Copyright Epic Games, Inc. All Rights Reserved. #include "UnsyncCommon.h" UNSYNC_THIRD_PARTY_INCLUDES_START #if UNSYNC_PLATFORM_WINDOWS # include // must be first # include # pragma comment(lib, "Ws2_32.lib") #endif // UNSYNC_PLATFORM_WINDOWS #if UNSYNC_PLATFORM_UNIX # include # include # include # include # include # include # define SOCKET int # define INVALID_SOCKET (SOCKET)(~0) # define SOCKET_ERROR (-1) # define closesocket(x) close(x) #endif // UNSYNC_PLATFORM_UNIX #include #include #include UNSYNC_THIRD_PARTY_INCLUDES_END #include "UnsyncBuffer.h" #include "UnsyncHash.h" #include "UnsyncLog.h" #include "UnsyncSocket.h" #include "UnsyncUtil.h" namespace unsync { static_assert(sizeof(FSocketAddress) >= sizeof(sockaddr), "SocketAddress is too small"); static_assert(sizeof(FSocketAddress) >= sizeof(sockaddr_in), "SocketAddress is too small"); const FSocketHandle InvalidSocketHandle = INVALID_SOCKET; struct FSocketInitHelper { FSocketInitHelper() { #if UNSYNC_PLATFORM_WINDOWS WSADATA Wsa; if (WSAStartup(MAKEWORD(2, 2), &Wsa) != 0) { UNSYNC_ERROR(L"WSAStartup failed: %d", WSAGetLastError()); } #endif // UNSYNC_PLATFORM_WINDOWS tls_init(); } ~FSocketInitHelper() { #if UNSYNC_PLATFORM_WINDOWS WSACleanup(); #endif // UNSYNC_PLATFORM_WINDOWS } }; static void LazyInitSockets() { static FSocketInitHelper InitHelper; } std::string GetCurrentHostName() { LazyInitSockets(); char Buffer[1024] = {}; if (gethostname(Buffer, (int)sizeof(Buffer)) == 0) { return std::string(Buffer); } else { return {}; } } static int32 GetLastSocketError() { #if UNSYNC_PLATFORM_WINDOWS return WSAGetLastError(); #else return -1; // TODO: report unix socket error #endif } static void ReportSocketError(int32 ErrorCode, ELogLevel LogLevel) { #if UNSYNC_PLATFORM_WINDOWS ErrorCode = WSAGetLastError(); #endif LogPrintf(LogLevel, L"Socket error: %d\n", ErrorCode); UNSYNC_BREAK_ON_ERROR_LEVEL(LogLevel); } static void ReportSocketError(tls* TlsCtx, ELogLevel LogLevel) { const char* ErrorString = tls_error(TlsCtx); LogPrintf(LogLevel, L"TLS error: %hs\n", ErrorString); UNSYNC_BREAK_ON_ERROR_LEVEL(LogLevel); } FSocketHandle SocketListenTcp(const char* Address, uint16 Port) { LazyInitSockets(); SOCKET ListenSocket = socket(AF_INET, SOCK_STREAM, 0); if (ListenSocket == InvalidSocketHandle) { UNSYNC_ERROR(L"Failed to create TCP socket (error code %d)", GetLastSocketError()); return InvalidSocketHandle; } sockaddr_in Service; Service.sin_family = AF_INET; Service.sin_addr.s_addr = inet_addr(Address); Service.sin_port = htons(Port); int32 BindResult = bind(ListenSocket, (sockaddr*)&Service, sizeof(Service)); if (BindResult == SOCKET_ERROR) { UNSYNC_ERROR(L"Failed to bind TCP socket (error code %d)", GetLastSocketError()); SocketClose(ListenSocket); return InvalidSocketHandle; } int32 ListenResult = listen(ListenSocket, 1); if (ListenResult == SOCKET_ERROR) { UNSYNC_ERROR(L"Failed to listen on TCP socket (error code %d)", GetLastSocketError()); SocketClose(ListenSocket); return InvalidSocketHandle; } return ListenSocket; } FSocketHandle SocketAccept(FSocketHandle ListenSocket) { SOCKET AcceptSocket = accept(ListenSocket, nullptr, nullptr); if (AcceptSocket == InvalidSocketHandle) { UNSYNC_ERROR(L"Failed to accept connection on TCP socket (error code %d)", GetLastSocketError()); return InvalidSocketHandle; } return AcceptSocket; } FSocketHandle SocketConnectTcp(const char* DestAddress, uint16 Port) { LazyInitSockets(); SOCKET Sock = socket(AF_INET, SOCK_STREAM, 0); if (Sock == InvalidSocketHandle) { UNSYNC_ERROR(L"Failed to create TCP socket (error code %d)", GetLastSocketError()); return InvalidSocketHandle; } sockaddr_in SockAddr = {}; SockAddr.sin_family = AF_INET; SockAddr.sin_port = htons(Port); struct addrinfo AddrHints = {}; AddrHints.ai_family = PF_UNSPEC; AddrHints.ai_socktype = SOCK_STREAM; AddrHints.ai_flags = AI_CANONNAME; struct addrinfo* Addr = {}; char ResolvedAddress[100] = {}; int Err = getaddrinfo(DestAddress, nullptr, &AddrHints, &Addr); if (Err == 0) { struct addrinfo* Curr = Addr; while (Curr) { if (Curr->ai_family == SockAddr.sin_family) { void* AddrData = nullptr; AddrData = &((struct sockaddr_in*)Curr->ai_addr)->sin_addr; if (inet_ntop(Curr->ai_family, AddrData, ResolvedAddress, sizeof(ResolvedAddress))) { DestAddress = ResolvedAddress; break; } } Curr = Curr->ai_next; } } else { UNSYNC_ERROR(L"Invalid address '%hs'", DestAddress); return InvalidSocketHandle; } if (Addr) { freeaddrinfo(Addr); } if (inet_pton(SockAddr.sin_family, DestAddress, &SockAddr.sin_addr) <= 0) { UNSYNC_ERROR(L"Invalid address '%hs'", DestAddress); return InvalidSocketHandle; } int32 ConnectResult = connect(Sock, (struct sockaddr*)&SockAddr, sizeof(SockAddr)); if (ConnectResult < 0) { closesocket(Sock); UNSYNC_LOG(L"Warning: Failed to connect to '%hs' on port %d", DestAddress, Port); return InvalidSocketHandle; } return Sock; } void SocketClose(FSocketHandle Socket) { if (Socket != InvalidSocketHandle) { closesocket(Socket); } } int SocketSend(FSocketHandle Socket, const void* Data, size_t DataSize) { if (DataSize == 0) { return 0; } UNSYNC_ASSERT(DataSize < INT_MAX); int Result = send(Socket, reinterpret_cast(Data), (int)DataSize, 0); if (Result < 0) { ReportSocketError(Result, ELogLevel::Warning); return 0; } else { return Result; } } int SocketRecvAll(FSocketHandle Socket, void* Data, size_t DataSize) { if (DataSize == 0) return 0; UNSYNC_ASSERT(DataSize < INT_MAX); // If socket timeout is set, recv may read partial data and return the // number of bytes read before timeout or <=0 if another error occurred. int ProcessedBytes = 0; while (ProcessedBytes < int(DataSize)) { char* BatchPtr = reinterpret_cast(Data) + ProcessedBytes; int BatchSize = (int)DataSize - ProcessedBytes; int Res = recv(Socket, BatchPtr, BatchSize, MSG_WAITALL); if (Res <= 0) { if (Res < 0) { ReportSocketError(Res, ELogLevel::Warning); } break; } ProcessedBytes += Res; } return ProcessedBytes; } int SocketRecvAny(FSocketHandle Socket, void* Data, size_t DataSize) { if (DataSize == 0) return 0; UNSYNC_ASSERT(DataSize < INT_MAX); int Res = recv(Socket, reinterpret_cast(Data), (int)DataSize, 0); if (Res < 0) { ReportSocketError(Res, ELogLevel::Warning); return 0; } return Res; } bool SocketSetRecvTimeout(FSocketHandle Socket, uint32 Seconds) { #if UNSYNC_PLATFORM_WINDOWS uint32 Timeout = Seconds * 1000; // Winsock uses timeout in milliseconds #else struct timeval Timeout; Timeout.tv_sec = long(Seconds); Timeout.tv_usec = 0; #endif int Result = setsockopt(Socket, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&Timeout), int(sizeof(Timeout))); return Result == 0; } bool SocketValid(FSocketHandle Socket) { if (Socket == INVALID_SOCKET) { return false; } int RecvResult = send(Socket, nullptr, 0, 0); return RecvResult != SOCKET_ERROR; } FSocketAddress SocketMakeAddress(const char* Address, uint16 Port) { sockaddr_in Addr = {}; Addr.sin_family = AF_INET; Addr.sin_addr.s_addr = inet_addr(Address); Addr.sin_port = htons(Port); FSocketAddress Result; memcpy(&Result, &Addr, sizeof(Addr)); return Result; } int FSocketRaw::Send(const void* Data, size_t DataSize) { return SocketSend(Handle, Data, DataSize); } int FSocketRaw::RecvAll(void* Data, size_t DataSize) { return SocketRecvAll(Handle, Data, DataSize); } int FSocketRaw::RecvAny(void* Data, size_t DataSize) { return SocketRecvAny(Handle, Data, DataSize); } static ssize_t TlsReadCallback(struct tls* TlsCtx, void* Buffer, size_t BufferSize, void* UserData) { FSocketHandle Socket = reinterpret_cast(UserData)->Handle; return SocketRecvAll(Socket, Buffer, BufferSize); } static ssize_t TlsWriteCallback(struct tls* TlsCtx, const void* Buffer, size_t BufferSize, void* UserData) { FSocketHandle Socket = reinterpret_cast(UserData)->Handle; return SocketSend(Socket, Buffer, BufferSize); } FSocketTls::FSocketTls(FSocketHandle InHandle, FTlsClientSettings ClientSettings) : FSocketBase(InHandle) { Security = ESocketSecurity::Unknown; tls_config* TlsCfg = tls_config_new(); if (ClientSettings.CACert.Data) { tls_config_set_ca_mem(TlsCfg, ClientSettings.CACert.Data, (size_t)ClientSettings.CACert.Size); } else { const FBuffer& SystemRootCerts = GetSystemRootCerts(); tls_config_set_ca_mem(TlsCfg, SystemRootCerts.Data(), SystemRootCerts.Size()); } if (!ClientSettings.bVerifyCertificate) { tls_config_insecure_noverifycert(TlsCfg); } if (!ClientSettings.bVerifySubject) { tls_config_insecure_noverifyname(TlsCfg); } TlsCtx = tls_client(); int Err = 0; if (Err == 0) { tls_configure(TlsCtx, TlsCfg); } tls_config_free(TlsCfg); if (Err == 0) { UNSYNC_ASSERT(!ClientSettings.Subject.empty()); std::string TlsSubject(ClientSettings.Subject); Err = tls_connect_cbs(TlsCtx, TlsReadCallback, TlsWriteCallback, this, TlsSubject.c_str()); } if (Err == 0) { for(;;) { Err = tls_handshake(TlsCtx); if (Err != TLS_WANT_POLLIN && Err != TLS_WANT_POLLOUT) { break; } } } if (Err) { const char* TlsErrorMsg = tls_error(TlsCtx); if (TlsErrorMsg) { UNSYNC_LOG(L"Warning: Failed to establish TLS connection: %hs", TlsErrorMsg); } else { UNSYNC_LOG(L"Warning: Failed to establish TLS connection"); } tls_free(TlsCtx); TlsCtx = nullptr; } else { const char* ConnVersion = tls_conn_version(TlsCtx); if (ConnVersion) { if (!strcmp(ConnVersion, "TLSv1.3")) { Security = ESocketSecurity::TLSv1_3; } else if (!strcmp(ConnVersion, "TLSv1.2")) { Security = ESocketSecurity::TLSv1_2; } } } } FSocketTls::~FSocketTls() { if (TlsCtx) { for(;;) { int Err = tls_close(TlsCtx); if (Err != TLS_WANT_POLLIN && Err != TLS_WANT_POLLOUT) { break; } } tls_free(TlsCtx); } } static ssize_t TlsWrite(struct tls* TlsCtx, const char* Buffer, size_t BufferSize) { ssize_t ErrorCode; do { ErrorCode = tls_write(TlsCtx, Buffer, BufferSize); } while(ErrorCode == TLS_WANT_POLLIN || ErrorCode == TLS_WANT_POLLOUT); return ErrorCode; } int FSocketTls::Send(const void* Data, size_t DataSize) { if (DataSize == 0) return 0; UNSYNC_ASSERT(DataSize < INT_MAX); int ProcessedBytes = 0; while (ProcessedBytes < DataSize) { ssize_t ErrorCode = TlsWrite(TlsCtx, (const char*)Data + ProcessedBytes, DataSize - ProcessedBytes); if (ErrorCode <= 0) { if (ErrorCode < 0) { ReportSocketError(TlsCtx, ELogLevel::Warning); } break; } int Res = CheckedNarrow(ErrorCode); ProcessedBytes += Res; } return ProcessedBytes; } static ssize_t TlsRead(struct tls* TlsCtx, char* Buffer, size_t BufferSize) { ssize_t ErrorCode; do { ErrorCode = tls_read(TlsCtx, Buffer, BufferSize); } while(ErrorCode == TLS_WANT_POLLIN || ErrorCode == TLS_WANT_POLLOUT); return ErrorCode; } int FSocketTls::RecvAll(void* Data, size_t DataSize) { if (DataSize == 0) return 0; UNSYNC_ASSERT(DataSize < INT_MAX); int ProcessedBytes = 0; while (ProcessedBytes < DataSize) { ssize_t ErrorCode = TlsRead(TlsCtx, (char*)Data + ProcessedBytes, DataSize - ProcessedBytes); if (ErrorCode <= 0) { if (ErrorCode < 0) { ReportSocketError(TlsCtx, ELogLevel::Warning); } break; } int Res = CheckedNarrow(ErrorCode); ProcessedBytes += Res; } return ProcessedBytes; } int FSocketTls::RecvAny(void* Data, size_t DataSize) { if (DataSize == 0) return 0; UNSYNC_ASSERT(DataSize < INT_MAX); int ProcessedBytes = 0; do { ssize_t ErrorCode = TlsRead(TlsCtx, (char*)Data + ProcessedBytes, DataSize - ProcessedBytes); if (ErrorCode <= 0) { if (ErrorCode < 0) { ReportSocketError(TlsCtx, ELogLevel::Warning); } break; } int Res = CheckedNarrow(ErrorCode); ProcessedBytes += Res; } while (false); return ProcessedBytes; } FSocketBase::~FSocketBase() { SocketClose(Handle); } const char* ToString(ESocketSecurity Security) { switch (Security) { case ESocketSecurity::None: return "None"; case ESocketSecurity::TLSv1_2: return "TLS 1.2"; case ESocketSecurity::TLSv1_3: return "TLS 1.3"; default: return "Unknown"; } } bool SendBuffer(FSocketBase& Socket, const FBufferView& Data) { UNSYNC_ASSERT(Data.Size <= std::numeric_limits::max()); int32 Size = int32(Data.Size); bool bOk = true; bOk &= SocketSendT(Socket, Size); bOk &= (SocketSend(Socket, Data.Data, Size) == Size); return bOk; } bool SendBuffer(FSocketBase& Socket, const FBuffer& Data) { return SendBuffer(Socket, Data.View()); } } // namespace unsync