#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/passes/inliner.h>

#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/peephole.h>

namespace torch {
namespace jit {
namespace {

// Builds the schema used for a function that has no explicitly assigned one.
// One argument is produced for each of the first function.num_inputs() graph
// inputs, and one unnamed return for each graph output; every type is passed
// through unshapedType(). Inputs without a debug name are called
// "argument_<index>". The schema's overload name is left empty.
c10::FunctionSchema defaultSchemaFor(const Function& function) {
  Graph& g = *function.graph();
  const size_t num_inputs = function.num_inputs();

  std::vector<c10::Argument> args;
  args.reserve(num_inputs);
  for (size_t i = 0; i < num_inputs; ++i) {
    // Bounds-checked lookup; `v` is reused below rather than re-indexing the
    // input list a second time without a check.
    const Value* v = g.inputs().at(i);
    std::string name = v->hasDebugName() ? v->debugNameBase()
                                         : ("argument_" + c10::to_string(i));
    args.emplace_back(std::move(name), unshapedType(v->type()));
  }

  std::vector<c10::Argument> returns;
  returns.reserve(g.outputs().size());
  for (const Value* out : g.outputs()) {
    returns.emplace_back("", unshapedType(out->type()));
  }

  return {function.name(), "", std::move(args), std::move(returns)};
}

} // namespace

// Installed as the function creator while the real creator runs (see
// GraphFunction::ensure_defined). If defining a function re-enters
// ensure_defined on that same function, this throws RecursiveMethodCallError.
void placeholderCreator(GraphFunction&) {
  throw RecursiveMethodCallError();
}

// Runs the function's executor on `stack` in place.
void GraphFunction::run(Stack& stack) {
  get_executor().run(stack);
}

// Overload for temporary stacks. Inside this body `stack` is a named
// reference, i.e. an lvalue, so this forwards to the Stack& overload.
void GraphFunction::run(Stack&& stack) {
  run(stack);
}

// Asynchronous variant: hands `stack` and the task launcher to the executor
// and returns the executor's future.
c10::intrusive_ptr<c10::ivalue::Future> GraphFunction::runAsync(
    Stack& stack,
    TaskLauncher taskLauncher) {
  return get_executor().runAsync(stack, std::move(taskLauncher));
}

// Calls the function with positional arguments `stack` and keyword arguments
// `kwargs`. The arguments are checked and normalized against getSchema(), the
// graph is run, and the first value left on the stack is returned.
IValue GraphFunction::operator()(
    std::vector<IValue> stack,
    const Kwargs& kwargs) {
  getSchema().checkAndNormalizeInputs(stack, kwargs);
  run(stack);
  // `stack` is a by-value local that is destroyed on return, so move the
  // result out rather than making an extra IValue copy.
  return std::move(stack.front());
}

// Lazily defines the function by running its creator at most once. While the
// creator runs, placeholderCreator is installed so that a recursive attempt to
// define this same function throws RecursiveMethodCallError. Once the creator
// returns it is cleared; check_single_output() is called on every invocation.
//
// NOTE(review): if `creator` throws, function_creator_ is left as
// placeholderCreator, so a later call reports RecursiveMethodCallError instead
// of the original failure. Left as-is because re-running a creator that failed
// part-way may not be safe -- confirm the intended behavior.
void GraphFunction::ensure_defined() {
  if (function_creator_) {
    auto creator = function_creator_;
    function_creator_ = placeholderCreator;
    creator(*this);
    function_creator_ = nullptr;
  }
  check_single_output();
}

// Returns the function's schema. If none has been set, builds one with
// defaultSchemaFor(*this) on first use and caches it in schema_.
// NOTE(review): the cache is filled from a const method with no locking
// visible here; confirm callers never race on the first call.
const c10::FunctionSchema& GraphFunction::getSchema() const {
  if (schema_ == nullptr) {
    schema_ = std::make_unique<c10::FunctionSchema>(defaultSchemaFor(*this));
  }
  return *schema_;
}

// Runs a fixed sequence of cleanup passes on `graph`, in this order: inlining,
// peephole optimization, constant propagation restricted to immutable types,
// and constant pooling.
void preoptimizeGraph(std::shared_ptr<Graph>& graph) {
  Inline(*graph);

  // Peephole Optimize cleans up many "is None" checks and creates constant prop
  // opportunities
  PeepholeOptimize(graph, true);

  // AliasDb construction can be slow, so run it just on immutable types
  // to clean up constant Ifs & other easy wins
  ConstantPropagationImmutableTypes(graph);
  ConstantPooling(graph);
}

} // namespace jit
} // namespace torch