Files
UnrealEngine/Engine/Source/ThirdParty/MaterialX/MaterialX-1.38.10/source/MaterialXGenShader/ShaderGraph.h
2025-05-18 13:04:45 +08:00

264 lines
10 KiB
C++

//
// Copyright Contributors to the MaterialX Project
// SPDX-License-Identifier: Apache-2.0
//
#ifndef MATERIALX_SHADERGRAPH_H
#define MATERIALX_SHADERGRAPH_H
/// @file
/// Shader graph class
#include <MaterialXGenShader/Export.h>
#include <MaterialXGenShader/ColorManagementSystem.h>
#include <MaterialXGenShader/ShaderNode.h>
#include <MaterialXGenShader/Syntax.h>
#include <MaterialXGenShader/TypeDesc.h>
#include <MaterialXGenShader/UnitSystem.h>
#include <MaterialXCore/Document.h>
#include <MaterialXCore/Node.h>
MATERIALX_NAMESPACE_BEGIN
// Forward declarations.
class Syntax;
class ShaderGraphEdge;
class ShaderGraphEdgeIterator;
class GenOptions;
/// An internal input socket in a shader graph,
/// used for connecting internal nodes to the outside.
/// It aliases ShaderOutput: ShaderGraph serves its input sockets through
/// the ShaderNode output accessors (see ShaderGraph::getInputSocket).
using ShaderGraphInputSocket = ShaderOutput;
/// An internal output socket in a shader graph,
/// used for connecting internal nodes to the outside.
/// It aliases ShaderInput: ShaderGraph serves its output sockets through
/// the ShaderNode input accessors (see ShaderGraph::getOutputSocket).
using ShaderGraphOutputSocket = ShaderInput;
/// A shared pointer to a shader graph.
/// The elaborated `class ShaderGraph` specifier also acts as the forward
/// declaration of ShaderGraph, which is defined further down in this header.
using ShaderGraphPtr = shared_ptr<class ShaderGraph>;
/// @class ShaderGraph
/// A graph (DAG) of shader nodes used for shader generation.
///
/// A ShaderGraph is itself a ShaderNode. Its input sockets are served by the
/// node's output accessors and its output sockets by the node's input
/// accessors, which is why the socket methods below forward to the
/// "opposite" ShaderNode method.
class MX_GENSHADER_API ShaderGraph : public ShaderNode
{
  public:
    /// Constructor.
    /// @param parent Parent graph of this graph.
    /// @param name Name of the graph.
    /// @param document Document stored in the graph's _document member.
    /// @param reservedWords Set of reserved words.
    ///        NOTE(review): presumably used to seed the identifier map - confirm in the implementation.
    ShaderGraph(const ShaderGraph* parent,
                const string& name,
                ConstDocumentPtr document,
                const StringSet& reservedWords);

    /// Destructor.
    virtual ~ShaderGraph() = default;

    /// Build a new shader graph from an element.
    /// Supported elements are outputs and shader nodes.
    static ShaderGraphPtr create(const ShaderGraph* parent,
                                 const string& name,
                                 ElementPtr element,
                                 GenContext& context);

    /// Build a new shader graph from a nodegraph.
    static ShaderGraphPtr create(const ShaderGraph* parent,
                                 const NodeGraph& nodeGraph,
                                 GenContext& context);

    /// Always true: this node is a graph.
    bool isAGraph() const override { return true; }

    /// Look up an internal node by name.
    ShaderNode* getNode(const string& name);

    /// Look up an internal node by name (const access).
    const ShaderNode* getNode(const string& name) const;

    /// Return the ordered vector of all nodes in the graph.
    const vector<ShaderNode*>& getNodes() const { return _nodeOrder; }

    /// Number of input sockets (forwards to numOutputs()).
    size_t numInputSockets() const { return numOutputs(); }

    /// Number of output sockets (forwards to numInputs()).
    size_t numOutputSockets() const { return numInputs(); }

    /// Socket access by index.
    /// The output socket overloads default to index 0.
    ShaderGraphInputSocket* getInputSocket(size_t index) { return getOutput(index); }
    ShaderGraphOutputSocket* getOutputSocket(size_t index = 0) { return getInput(index); }
    const ShaderGraphInputSocket* getInputSocket(size_t index) const { return getOutput(index); }
    const ShaderGraphOutputSocket* getOutputSocket(size_t index = 0) const { return getInput(index); }

    /// Socket access by name.
    ShaderGraphInputSocket* getInputSocket(const string& name) { return getOutput(name); }
    ShaderGraphOutputSocket* getOutputSocket(const string& name) { return getInput(name); }
    const ShaderGraphInputSocket* getInputSocket(const string& name) const { return getOutput(name); }
    const ShaderGraphOutputSocket* getOutputSocket(const string& name) const { return getInput(name); }

    /// Access to the full socket vectors.
    const vector<ShaderGraphInputSocket*>& getInputSockets() const { return _outputOrder; }
    const vector<ShaderGraphOutputSocket*>& getOutputSockets() const { return _inputOrder; }

    /// Apply color and unit transforms to each input of a node.
    void applyInputTransforms(ConstNodePtr node, ShaderNodePtr shaderNode, GenContext& context);

    /// Create a new node in the graph.
    ShaderNode* createNode(ConstNodePtr node, GenContext& context);

    /// Add an input socket with the given name and type.
    ShaderGraphInputSocket* addInputSocket(const string& name, const TypeDesc* type);

    /// Add an output socket with the given name and type.
    ShaderGraphOutputSocket* addOutputSocket(const string& name, const TypeDesc* type);

    /// Add a default geometric node and connect it to the given input.
    void addDefaultGeomNode(ShaderInput* input, const GeomPropDef& geomprop, GenContext& context);

    /// Sort the nodes in topological order.
    void topologicalSort();

    /// Return an iterator for traversing upstream from the given output.
    static ShaderGraphEdgeIterator traverseUpstream(ShaderOutput* output);

    /// Return the map of unique identifiers used in the scope of this graph.
    IdentifierMap& getIdentifierMap() { return _identifiers; }

  protected:
    /// Create node connections corresponding to the connection between a pair of elements.
    /// @param downstreamElement Element representing the node to connect to.
    /// @param upstreamElement Element representing the node to connect from.
    /// @param connectingElement If non-null, specifies the element on the downstream node to connect to.
    /// @param context Context for generation.
    void createConnectedNodes(const ElementPtr& downstreamElement,
                              const ElementPtr& upstreamElement,
                              ElementPtr connectingElement,
                              GenContext& context);

    /// Add a node to the graph.
    void addNode(ShaderNodePtr node);

    /// Add input sockets from an interface element (nodedef, nodegraph or node).
    void addInputSockets(const InterfaceElement& elem, GenContext& context);

    /// Add output sockets from an interface element (nodedef, nodegraph or node).
    void addOutputSockets(const InterfaceElement& elem);

    /// Traverse from the given root element and add all dependencies upstream.
    /// NOTE(review): earlier documentation described the traversal as done
    /// "in the context of a material, if given, to include bind input
    /// elements"; the current signature takes no material argument.
    void addUpstreamDependencies(const Element& root, GenContext& context);

    /// Add a color transform node and connect it to the given input.
    void addColorTransformNode(ShaderInput* input, const ColorSpaceTransform& transform, GenContext& context);

    /// Add a color transform node and connect it to the given output.
    void addColorTransformNode(ShaderOutput* output, const ColorSpaceTransform& transform, GenContext& context);

    /// Add a unit transform node and connect it to the given input.
    void addUnitTransformNode(ShaderInput* input, const UnitTransform& transform, GenContext& context);

    /// Add a unit transform node and connect it to the given output.
    void addUnitTransformNode(ShaderOutput* output, const UnitTransform& transform, GenContext& context);

    /// Perform all post-build operations on the graph.
    void finalize(GenContext& context);

    /// Optimize the graph, removing redundant paths.
    void optimize(GenContext& context);

    /// Bypass a node for a particular input and output, effectively
    /// connecting the input's upstream connection with the output's
    /// downstream connections. The output index defaults to 0.
    void bypass(GenContext& context, ShaderNode* node, size_t inputIndex, size_t outputIndex = 0);

    /// Set the variable names to be used in generated code for the inputs
    /// and outputs in the graph, making sure the names are valid and unique
    /// so that no name conflicts arise during shader generation.
    void setVariableNames(GenContext& context);

    /// Populate the color transform map for the given shader port, if the
    /// provided combination of source and target color spaces is supported
    /// for its data type.
    void populateColorTransformMap(ColorManagementSystemPtr colorManagementSystem,
                                   ShaderPort* shaderPort,
                                   const string& sourceColorSpace,
                                   const string& targetColorSpace,
                                   bool asInput);

    /// Populate the appropriate unit transform map if the provided
    /// input/parameter or output has a unit attribute and is of a supported type.
    void populateUnitTransformMap(UnitSystemPtr unitSystem,
                                  ShaderPort* shaderPort,
                                  ValueElementPtr element,
                                  const string& targetUnitSpace,
                                  bool asInput);

    /// Break all connections on a node.
    void disconnect(ShaderNode* node) const;

    // Document given at construction.
    ConstDocumentPtr _document;

    // Nodes keyed by name (shared ownership), plus the ordered view
    // returned by getNodes().
    std::unordered_map<string, ShaderNodePtr> _nodeMap;
    std::vector<ShaderNode*> _nodeOrder;

    // Identifier map returned by getIdentifierMap().
    IdentifierMap _identifiers;

    // Temporary storage for inputs that require color transformations
    std::unordered_map<ShaderInput*, ColorSpaceTransform> _inputColorTransformMap;

    // Temporary storage for inputs that require unit transformations
    std::unordered_map<ShaderInput*, UnitTransform> _inputUnitTransformMap;

    // Temporary storage for outputs that require color transformations
    std::unordered_map<ShaderOutput*, ColorSpaceTransform> _outputColorTransformMap;

    // Temporary storage for outputs that require unit transformations
    std::unordered_map<ShaderOutput*, UnitTransform> _outputUnitTransformMap;
};
/// @class ShaderGraphEdge
/// An edge returned during shader graph traversal.
class MX_GENSHADER_API ShaderGraphEdge
{
public:
ShaderGraphEdge(ShaderOutput* up, ShaderInput* down) :
upstream(up),
downstream(down)
{
}
ShaderOutput* upstream;
ShaderInput* downstream;
};
/// @class ShaderGraphEdgeIterator
/// Iterator class for traversing edges between nodes in a shader graph.
class MX_GENSHADER_API ShaderGraphEdgeIterator
{
  public:
    /// Construct an iterator for the given output.
    ShaderGraphEdgeIterator(ShaderOutput* output);

    ~ShaderGraphEdgeIterator() = default;

    /// Two iterators are equal when their current upstream output, current
    /// downstream input and traversal stack all match.
    bool operator==(const ShaderGraphEdgeIterator& rhs) const
    {
        if (_upstream != rhs._upstream || _downstream != rhs._downstream)
        {
            return false;
        }
        return _stack == rhs._stack;
    }

    /// Negation of operator==.
    bool operator!=(const ShaderGraphEdgeIterator& rhs) const
    {
        return !operator==(rhs);
    }

    /// Dereference this iterator, returning the current edge in the traversal.
    ShaderGraphEdge operator*() const
    {
        return { _upstream, _downstream };
    }

    /// Iterate to the next edge in the traversal.
    /// @throws ExceptionFoundCycle if a cycle is encountered.
    ShaderGraphEdgeIterator& operator++();

    /// Return a reference to this iterator to begin traversal.
    ShaderGraphEdgeIterator& begin()
    {
        return *this;
    }

    /// Return the end iterator.
    static const ShaderGraphEdgeIterator& end();

  private:
    void extendPathUpstream(ShaderOutput* upstream, ShaderInput* downstream);
    void returnPathDownstream(ShaderOutput* upstream);

    // Current edge, as returned by operator*.
    ShaderOutput* _upstream;
    ShaderInput* _downstream;

    // Traversal stack.
    // NOTE(review): each frame presumably pairs an output with the index of
    // the next input to visit - confirm against the implementation.
    using StackFrame = std::pair<ShaderOutput*, size_t>;
    std::vector<StackFrame> _stack;

    // NOTE(review): presumably the outputs on the current traversal path,
    // used for the cycle detection mentioned on operator++ - confirm.
    std::set<ShaderOutput*> _path;
};
MATERIALX_NAMESPACE_END
#endif