Files
UnrealEngine/Engine/Plugins/Interchange/Runtime/Source/Dispatcher/Private/InterchangeDispatcherNetworking.cpp
2025-05-18 13:04:45 +08:00

375 lines
10 KiB
C++

// Copyright Epic Games, Inc. All Rights Reserved.
#include "InterchangeDispatcherNetworking.h"
#include "CoreMinimal.h"
#include "HAL/PlatformProcess.h"
#include "HAL/PlatformTime.h"
#include "InterchangeDispatcherLog.h"
#include "Misc/ScopeLock.h"
#include "SocketSubsystem.h"
#include "Sockets.h"
#include "Serialization/MemoryReader.h"
#include "Serialization/MemoryWriter.h"
namespace UE
{
namespace Interchange
{
/** Converts a duration in (possibly fractional) seconds into an FTimespan, truncating to whole ticks. */
static FTimespan TimespanFromSeconds(double Seconds)
{
	const double Ticks = Seconds * ETimespan::TicksPerSecond;
	return FTimespan(static_cast<int64>(Ticks));
}
/** Releases the connected socket (if any) through CloseSocket. */
FNetworkNode::~FNetworkNode()
{
	this->CloseSocket(this->ConnectedSocket);
}
bool FNetworkNode::SendMessage(const TArray<uint8>& Message, double Timeout_s)
{
FScopeLock Lock(&SendReceiveCriticalSection);
if (bWriteError)
{
UE_LOG(LogInterchangeDispatcher, Display, TEXT("bWriteError flag raised, can't write"));
return false;
}
if (ConnectedSocket == nullptr)
{
UE_LOG(LogInterchangeDispatcher, Display, TEXT("node not connected, can't write"));
return false;
}
ESocketConnectionState State = ConnectedSocket->GetConnectionState();
if (State == ESocketConnectionState::SCS_ConnectionError)
{
UE_LOG(LogInterchangeDispatcher, Display, TEXT("connection state error"));
return false;
}
bool bCanWrite = ConnectedSocket->Wait(ESocketWaitConditions::WaitForWrite, TimespanFromSeconds(Timeout_s));
if (!bCanWrite)
{
UE_LOG(LogInterchangeDispatcher, Display, TEXT("can't write on socket"));
return false;
}
ConnectedSocket->SetNonBlocking(false);
auto SendBuffer = [&](const TArray<uint8>& Buffer) -> bool
{
uint32 BufferSize = Buffer.Num();
const uint8* Data = Buffer.GetData();
uint32 TotalByteSent = 0;
bool bSendSucceed = true;
while (bSendSucceed && TotalByteSent < BufferSize)
{
int32 BytesSent = 0;
bSendSucceed &= ConnectedSocket->Send(Data + TotalByteSent, BufferSize - TotalByteSent, BytesSent);
TotalByteSent += (uint32)BytesSent;
}
return bSendSucceed && TotalByteSent == BufferSize;
};
// Send the header
TArray<uint8> HeaderBuffer;
FMemoryWriter ArWriter(HeaderBuffer);
FMessageHeader Header;
Header.ByteSize = Message.Num();
ArWriter << Header;
if (!SendBuffer(HeaderBuffer))
{
UE_LOG(LogInterchangeDispatcher, Display, TEXT("can't write header on socket"));
bWriteError = true;
return false;
}
// Send the content
if (!SendBuffer(Message))
{
UE_LOG(LogInterchangeDispatcher, Display, TEXT("can't write content on socket"));
bWriteError = true;
return false;
}
return true;
}
/**
 * Incrementally receives one framed message (serialized FMessageHeader + payload) from the connected socket.
 *
 * The socket is used in non-blocking mode. Each call consumes whatever bytes are currently available.
 * A partially received message is kept in IncommingMessage across calls until its payload is complete.
 * Calls are serialized with SendMessage through SendReceiveCriticalSection.
 *
 * @param OutMessage  Receives the complete payload when this call finishes a message (untouched otherwise).
 * @param Timeout_s   Maximum time, in seconds, to wait for the socket to become readable.
 * @return true only when a full message was delivered into OutMessage. A malformed header, an oversized
 *         message or a failed read raises bReadError, after which every later call returns false immediately.
 */
bool FNetworkNode::ReceiveMessage(TArray<uint8>& OutMessage, double Timeout_s)
{
	// Upper bound on an accepted payload size; larger headers are treated as corrupt.
	constexpr int32 MaxMessageByteSize = 1 << 20;

	FScopeLock Lock(&SendReceiveCriticalSection);
	if (bReadError)
	{
		UE_LOG(LogInterchangeDispatcher, Display, TEXT("ReadError flag raised, can't read"));
		return false;
	}
	if (ConnectedSocket == nullptr)
	{
		UE_LOG(LogInterchangeDispatcher, Display, TEXT("node not connected, can't read"));
		return false;
	}
	ConnectedSocket->SetNonBlocking(true);
	if (!ConnectedSocket->Wait(ESocketWaitConditions::WaitForRead, TimespanFromSeconds(Timeout_s)))
	{
		return false;
	}

	// A negative ByteSize means no header has been parsed for the current message yet
	// (presumably the FMessage default value -- the reset below relies on it; confirm in the header).
	if (IncommingMessage.Header.ByteSize < 0)
	{
		// Serialized size of an FMessageHeader, computed once.
		static const uint32 HeaderByteSize = []()
		{
			FMessageHeader Header;
			TArray<uint8> HeaderBuffer;
			FMemoryWriter ArWriter(HeaderBuffer);
			ArWriter << Header;
			return HeaderBuffer.Num();
		}();

		// Wait for the whole header to be available before consuming any of it.
		uint32 PendingByteCount = 0;
		if (!ConnectedSocket->HasPendingData(PendingByteCount) || PendingByteCount < HeaderByteSize)
		{
			return false;
		}
		TArray<uint8> HeaderBuffer;
		HeaderBuffer.AddZeroed(HeaderByteSize);
		int32 BytesRead = -1;
		const bool bRecvSucceed = ConnectedSocket->Recv(HeaderBuffer.GetData(), HeaderBuffer.Num(), BytesRead);
		if (!bRecvSucceed || BytesRead != int32(HeaderByteSize))
		{
			UE_LOG(LogInterchangeDispatcher, Display, TEXT("Parsed header failed"));
			bReadError = true;
			return false;
		}
		FMemoryReader ArReader(HeaderBuffer);
		ArReader << IncommingMessage.Header;
		if (IncommingMessage.Header.ByteSize < 0 || IncommingMessage.Header.ByteSize > MaxMessageByteSize)
		{
			UE_LOG(LogInterchangeDispatcher, Display, TEXT("Parsed header failed: bad message size %d"), IncommingMessage.Header.ByteSize);
			bReadError = true;
			return false;
		}
		IncommingMessage.Content.Reserve(IncommingMessage.Header.ByteSize);
	}

	// From here the header is valid (0 <= ByteSize <= MaxMessageByteSize): append available payload bytes.
	int32 MissingByteInCurrentMessage = IncommingMessage.Header.ByteSize - IncommingMessage.Content.Num();
	if (MissingByteInCurrentMessage > 0)
	{
		uint32 PendingByteCount = 0;
		if (!ConnectedSocket->HasPendingData(PendingByteCount))
		{
			return false;
		}
		// Never read past the end of the current message: following bytes belong to the next frame.
		const int32 ReadTarget = PendingByteCount < uint32(MissingByteInCurrentMessage) ? int32(PendingByteCount) : MissingByteInCurrentMessage;

		// Grow first, then take the write pointer, so it stays valid even if the array reallocates.
		const int32 WriteOffset = IncommingMessage.Content.Num();
		IncommingMessage.Content.AddUninitialized(ReadTarget);
		uint8* Destination = IncommingMessage.Content.GetData() + WriteOffset;

		int32 BytesRead = -1;
		const bool bRecvSucceed = ConnectedSocket->Recv(Destination, ReadTarget, BytesRead);
		if (!bRecvSucceed || BytesRead != ReadTarget)
		{
			bReadError = true;
			UE_LOG(LogInterchangeDispatcher, Display, TEXT("Recv issue"));
			return false;
		}
		MissingByteInCurrentMessage -= BytesRead;
	}

	// Also covers zero-length payloads, which complete as soon as their header is parsed.
	if (MissingByteInCurrentMessage == 0)
	{
		OutMessage = MoveTemp(IncommingMessage.Content);
		IncommingMessage = FMessage();
		return true;
	}
	return false;
}
/**
 * Creates a new (unbound, unconnected) stream socket through the platform socket subsystem.
 * The protocol type is taken from a loopback address. If that address is reported invalid,
 * NAME_None is passed and the subsystem's default is used.
 *
 * @param Description  Debug description attached to the socket.
 * @return The new socket, or nullptr on failure (logged). The caller owns it; release with CloseSocket.
 */
FSocket* FNetworkNode::CreateInternalSocket(const FString& Description)
{
	ISocketSubsystem* SocketSubsystem = ISocketSubsystem::Get();

	// Probe a loopback address only to learn which protocol type the subsystem reports for it.
	TSharedRef<FInternetAddr> InternetAddress = SocketSubsystem->CreateInternetAddr();
	InternetAddress->SetLoopbackAddress();
	FName Protocol = NAME_None;
	if (ensure(InternetAddress->IsValid()))
	{
		Protocol = InternetAddress->GetProtocolType();
	}

	FSocket* NewSocket = SocketSubsystem->CreateSocket(NAME_Stream, *Description, Protocol);
	UE_CLOG(!NewSocket, LogInterchangeDispatcher, Display, TEXT("Socket creation failure"));
	return NewSocket;
}
/**
 * Closes Socket and hands it back to the socket subsystem for destruction, then nulls the caller's pointer.
 * Does nothing when Socket is already nullptr.
 */
void FNetworkNode::CloseSocket(FSocket*& Socket)
{
	if (Socket == nullptr)
	{
		return;
	}
	Socket->Close();
	ISocketSubsystem::Get()->DestroySocket(Socket);
	Socket = nullptr;
}
bool FNetworkNode::IsConnected()
{
return ConnectedSocket && ConnectedSocket->GetConnectionState() == ESocketConnectionState::SCS_Connected;
}
/**
 * Creates the listening socket, binds it to a loopback address and starts listening.
 * On failure, ConnectedSocketError records the failing step (Error_Create, Error_Bind or Error_Listen).
 * A socket that was created but failed to bind or listen is kept and released by the destructor.
 */
FNetworkServerNode::FNetworkServerNode()
{
	ListeningSocket = CreateInternalSocket(TEXT("Interchange listening socket"));
	if (!ListeningSocket)
	{
		ConnectedSocketError = SocketErrorCode::Error_Create;
		return;
	}

	ISocketSubsystem* SocketSubsystem = ISocketSubsystem::Get();
	TSharedRef<FInternetAddr> InternetAddress = SocketSubsystem->CreateInternetAddr();
	InternetAddress->SetLoopbackAddress();

	// BindNextPort's count/increment parameters are integers: try up to 1000 ports, stepping by 1.
	// NOTE(review): the address port is never set here, so the OS presumably picks an ephemeral port on the first try.
	constexpr int32 BindPortCount = 1000;
	constexpr int32 BindPortIncrement = 1;
	const int32 BoundPort = SocketSubsystem->BindNextPort(ListeningSocket, *InternetAddress, BindPortCount, BindPortIncrement);
	if (BoundPort == 0)
	{
		ConnectedSocketError = SocketErrorCode::Error_Bind;
		UE_LOG(LogInterchangeDispatcher, Display, TEXT("Socket binding failure"));
		return;
	}
	ensure(BoundPort == ListeningSocket->GetPortNo());

	if (!ListeningSocket->Listen(0))
	{
		ConnectedSocketError = SocketErrorCode::Error_Listen;
		UE_LOG(LogInterchangeDispatcher, Display, TEXT("Socket listen failure"));
		return;
	}
}
/** Releases the listening socket (if any); the connected socket is released by the base destructor. */
FNetworkServerNode::~FNetworkServerNode()
{
	this->CloseSocket(this->ListeningSocket);
}
/** Port number reported by the listening socket, or 0 when no listening socket exists. */
int32 FNetworkServerNode::GetListeningPort()
{
	if (ListeningSocket == nullptr)
	{
		return 0;
	}
	return ListeningSocket->GetPortNo();
}
/**
 * Waits for a client connection on the listening socket and stores it in ConnectedSocket.
 * Any previously accepted connection is closed first. The socket is polled in non-blocking mode every
 * 0.1 s until Timeout_s elapses, and at least one attempt is always made.
 *
 * @param Description  Debug description for the accepted socket.
 * @param Timeout_s    Maximum time, in seconds, to keep polling.
 * @return true when a connection was accepted; false on timeout or when no listening socket exists
 *         (e.g. socket creation failed in the constructor).
 */
bool FNetworkServerNode::Accept(const FString& Description, double Timeout_s)
{
	CloseSocket(ConnectedSocket);

	// The constructor leaves ListeningSocket null when socket creation failed; don't dereference it.
	if (!ListeningSocket)
	{
		return false;
	}

	// wait until a connection occurs
	ListeningSocket->SetNonBlocking();
	const double AcceptLimit_s = FPlatformTime::Seconds() + Timeout_s;
	do
	{
		ConnectedSocket = ListeningSocket->Accept(Description);
		if (ConnectedSocket)
		{
			return true;
		}
		FPlatformProcess::Sleep(0.1);
	} while (FPlatformTime::Seconds() < AcceptLimit_s);
	return false;
}
bool FNetworkClientNode::Connect(const FString& Description, int32 ServerPort, double Timeout_s)
{
CloseSocket(ConnectedSocket);
ConnectedSocket = CreateInternalSocket(Description);
if (!ConnectedSocket)
{
ConnectedSocketError = SocketErrorCode::Error_Create;
return false;
}
ISocketSubsystem* SocketSubsystem = ISocketSubsystem::Get();
TSharedRef<FInternetAddr> InternetAddress = SocketSubsystem->CreateInternetAddr();
InternetAddress->SetLoopbackAddress();
InternetAddress->SetPort(ServerPort);
double ConnectTimeout_s = FPlatformTime::Seconds() + Timeout_s;
ConnectedSocket->SetNonBlocking(true);
do
{
if (ConnectedSocket->Connect(*InternetAddress))
{
UE_LOG(LogInterchangeDispatcher, Verbose, TEXT("Client Node is connected"));
return true;
}
} while (FPlatformTime::Seconds() < ConnectTimeout_s);
UE_LOG(LogInterchangeDispatcher, Display, TEXT("Client socket failed to connect"));
CloseSocket(ConnectedSocket);
return false;
}
/**
 * Sets the transport used by SendCommand and Poll (nullptr detaches it).
 * The pointer is stored as-is; presumably the caller keeps the node alive while attached -- confirm.
 */
void FCommandQueue::SetNetworkInterface(FNetworkNode* InNetworkInterface)
{
	this->NetworkInterface = InNetworkInterface;
}
/**
 * Drains every command currently available from the network into InCommands, then pops the next queued one.
 * Only the first poll waits up to Timeout_s; subsequent polls use a zero timeout.
 * @return The dequeued command, or nullptr when the queue is empty.
 */
TSharedPtr<ICommand> FCommandQueue::GetNextCommand(double Timeout_s)
{
	for (double PollTimeout_s = Timeout_s; Poll(PollTimeout_s); PollTimeout_s = 0)
	{
		// keep draining
	}

	TSharedPtr<ICommand> NextCommand;
	if (InCommands.Dequeue(NextCommand))
	{
		return NextCommand;
	}
	return nullptr;
}
/**
 * Serializes Commmand and sends it as one message through the attached network interface.
 * @return false when no interface is attached or the send fails.
 */
bool FCommandQueue::SendCommand(ICommand& Commmand, double Timeout_s)
{
	TArray<uint8> SerializedCommand;
	SerializeCommand(Commmand, SerializedCommand);
	if (NetworkInterface == nullptr)
	{
		return false;
	}
	return NetworkInterface->SendMessage(SerializedCommand, Timeout_s);
}
/**
 * Queues every command still available on the network (first poll waits up to Timeout_s, later ones don't),
 * then detaches the network interface. The interface's socket itself is not closed here.
 */
void FCommandQueue::Disconnect(double Timeout_s)
{
	for (double PollTimeout_s = Timeout_s; Poll(PollTimeout_s); PollTimeout_s = 0)
	{
		// keep draining before detaching
	}
	NetworkInterface = nullptr;
}
bool FCommandQueue::Poll(double Timeout_s)
{
TArray<uint8> CommandBuffer;
if (NetworkInterface && NetworkInterface->ReceiveMessage(CommandBuffer, Timeout_s))
{
if (TSharedPtr<ICommand> Command = DeserializeCommand(CommandBuffer))
{
InCommands.Enqueue(Command);
return true;
}
}
return false;
}
} //ns Interchange
}//ns UE