diff --git a/.gitignore b/.gitignore index 5315c68f81..7ebc17fcf1 100644 --- a/.gitignore +++ b/.gitignore @@ -36,6 +36,7 @@ open_spiel/games/universal_poker/double_dummy_solver/ open_spiel/games/hanabi/hanabi-learning-environment/ /open_spiel/pybind11_abseil/ pybind11/ +!open_spiel/python/pybind11 # Install artifacts download_cache/ diff --git a/open_spiel/algorithms/infostate_tree.cc b/open_spiel/algorithms/infostate_tree.cc index 61584e2d0d..df0e0c6435 100644 --- a/open_spiel/algorithms/infostate_tree.cc +++ b/open_spiel/algorithms/infostate_tree.cc @@ -200,15 +200,15 @@ void InfostateTree::CollectNodesAtDepth(InfostateNode* node, size_t depth) { CollectNodesAtDepth(child, depth + 1); } -std::ostream& InfostateTree::operator<<(std::ostream& os) const { - return os << "Infostate tree for player " << acting_player_ << ".\n" - << "Tree height: " << tree_height_ << '\n' - << "Root branching: " << root_branching_factor() << '\n' - << "Number of decision infostate nodes: " << num_decisions() << '\n' - << "Number of sequences: " << num_sequences() << '\n' - << "Number of leaves: " << num_leaves() << '\n' +std::ostream& operator<<(std::ostream& os, const InfostateTree& tree) { + return os << "Infostate tree for player " << tree.acting_player_ << ".\n" + << "Tree height: " << tree.tree_height_ << '\n' + << "Root branching: " << tree.root_branching_factor() << '\n' + << "Number of decision infostate nodes: " << tree.num_decisions() << '\n' + << "Number of sequences: " << tree.num_sequences() << '\n' + << "Number of leaves: " << tree.num_leaves() << '\n' << "Tree certificate: " << '\n' - << root().MakeCertificate() << '\n'; + << tree.root().MakeCertificate() << '\n'; } std::unique_ptr InfostateTree::MakeNode( @@ -485,7 +485,8 @@ absl::optional InfostateTree::DecisionIdForSequence( } } absl::optional InfostateTree::DecisionForSequence( - const SequenceId& sequence_id) { + const SequenceId& sequence_id) const +{ SPIEL_DCHECK_TRUE(sequence_id.BelongsToTree(this)); InfostateNode* node = sequences_.at(sequence_id.id()); SPIEL_DCHECK_TRUE(node); @@ -595,7 +596,7 @@ std::pair InfostateTree::CollectStartEndSequenceIds( } std::pair InfostateTree::BestResponse( - TreeplexVector&& gradient) const { + TreeplexVector gradient) const { SPIEL_CHECK_EQ(this, gradient.tree()); SPIEL_CHECK_EQ(num_sequences(), gradient.size()); SfStrategy response(this); @@ -647,7 +648,7 @@ std::pair InfostateTree::BestResponse( return {gradient[empty_sequence()], response}; } -double InfostateTree::BestResponseValue(LeafVector&& gradient) const { +double InfostateTree::BestResponseValue(LeafVector gradient) const { // Loop over all heights. for (int d = tree_height_ - 1; d >= 0; d--) { int left_offset = 0; diff --git a/open_spiel/algorithms/infostate_tree.h b/open_spiel/algorithms/infostate_tree.h index 2f02dd1041..a4c9fe22ef 100644 --- a/open_spiel/algorithms/infostate_tree.h +++ b/open_spiel/algorithms/infostate_tree.h @@ -237,6 +237,9 @@ class RangeIterator { bool operator!=(const RangeIterator& other) const { return id_ != other.id_ || tree_ != other.tree_; } + bool operator==(const RangeIterator& other) const { + return !(this->operator!=(other)); + } Id operator*() { return Id(id_, tree_); } }; template @@ -285,7 +288,7 @@ std::shared_ptr MakeInfostateTree( const std::vector& start_nodes, int max_move_ahead_limit = 1000); -class InfostateTree final { +class InfostateTree final : public std::enable_shared_from_this { // Note that only MakeInfostateTree is allowed to call the constructor // to ensure the trees are always allocated on heap. We do this so that all // the collected pointers are valid throughout the tree's lifetime even if @@ -305,6 +308,10 @@ class InfostateTree final { const std::vector&, int); public: + // -- gain shared ownership of the allocated infostate object + std::shared_ptr< InfostateTree > shared_ptr() { return shared_from_this(); } + std::shared_ptr< const InfostateTree > shared_ptr() const { return shared_from_this(); } + // -- Root accessors --------------------------------------------------------- const InfostateNode& root() const { return *root_; } InfostateNode* mutable_root() { return root_.get(); } @@ -344,7 +351,7 @@ class InfostateTree final { // Returns `None` if the sequence is the empty sequence. absl::optional DecisionIdForSequence(const SequenceId&) const; // Returns `None` if the sequence is the empty sequence. - absl::optional DecisionForSequence(const SequenceId&); + absl::optional DecisionForSequence(const SequenceId& sequence_id) const; // Returns whether the sequence ends with the last action the player can make. bool IsLeafSequence(const SequenceId&) const; @@ -385,13 +392,13 @@ class InfostateTree final { // Compute best response and value based on gradient from opponents. // This consumes the gradient vector, as it is used to compute the value. std::pair BestResponse( - TreeplexVector&& gradient) const; + TreeplexVector gradient) const; // Compute best response value based on gradient from opponents over leaves. // This consumes the gradient vector, as it is used to compute the value. - double BestResponseValue(LeafVector&& gradient) const; + double BestResponseValue(LeafVector gradient) const; // -- For debugging ---------------------------------------------------------- - std::ostream& operator<<(std::ostream& os) const; + friend std::ostream& operator<<(std::ostream& os, const InfostateTree& tree); private: const Player acting_player_; @@ -550,6 +557,7 @@ class InfostateNode final { const InfostateNodeType& type() const { return type_; } size_t depth() const { return depth_; } bool is_root_node() const { return !parent_; } + bool is_filler_node() const { return infostate_string_ == kFillerInfostate; } bool has_infostate_string() const { return infostate_string_ != kFillerInfostate && infostate_string_ != kDummyRootNodeInfostate; diff --git a/open_spiel/python/CMakeLists.txt b/open_spiel/python/CMakeLists.txt index ca5ed2ece1..0af972dbf1 100644 --- a/open_spiel/python/CMakeLists.txt +++ b/open_spiel/python/CMakeLists.txt @@ -77,6 +77,9 @@ endif() # List of all Python bindings to add to pyspiel. include_directories (../pybind11_abseil ../../pybind11/include) set(PYTHON_BINDINGS ${PYTHON_BINDINGS} + pybind11/algorithms_infostate_tree.cc + pybind11/algorithms_infostate_tree.tcc + pybind11/algorithms_infostate_tree.h pybind11/algorithms_corr_dist.cc pybind11/algorithms_corr_dist.h pybind11/algorithms_trajectories.cc diff --git a/open_spiel/python/pybind11/algorithms_infostate_tree.cc b/open_spiel/python/pybind11/algorithms_infostate_tree.cc new file mode 100644 index 0000000000..af2b54603e --- /dev/null +++ b/open_spiel/python/pybind11/algorithms_infostate_tree.cc @@ -0,0 +1,372 @@ +// Copyright 2021 DeepMind Technologies Limited +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "open_spiel/python/pybind11/algorithms_infostate_tree.h" + +#include + +#include "open_spiel/algorithms/infostate_tree.h" +#include "pybind11/stl_bind.h" + +namespace py = ::pybind11; + +namespace open_spiel { + +using namespace algorithms; + +using infostatenode_holder_ptr = MockUniquePtr< InfostateNode >; +using const_infostatenode_holder_ptr = MockUniquePtr< const InfostateNode >; + +class InfostateNodeChildIterator { + using iter_type = VecWithUniquePtrsIterator< InfostateNode >; + + iter_type iter_; + + public: + explicit InfostateNodeChildIterator(iter_type it) : iter_(it) {} + InfostateNodeChildIterator& operator++() { ++iter_; return *this; } + bool operator==(const InfostateNodeChildIterator &other) const { return iter_ == other.iter_; } + bool operator!=(const InfostateNodeChildIterator &other) const { return ! (*this == other); } + + // this dereferencing operator wrap is the reason for the class adaptor. + // We need to ensure that each node is wrapped in a non owning unique ptr mock. + auto operator*() { return infostatenode_holder_ptr{*iter_}; } + + auto begin() const { return InfostateNodeChildIterator{iter_.begin()}; } + auto end() const { return InfostateNodeChildIterator{iter_.end()}; } +}; + +void init_pyspiel_infostate_node(::pybind11::module &m) +{ + py::class_< InfostateNode, infostatenode_holder_ptr >(m, "InfostateNode", py::is_final()) + .def("tree", [](const InfostateNode &node) { return node.tree().shared_ptr(); }) + .def( + "parent", [](const InfostateNode &node) { return infostatenode_holder_ptr{node.parent()}; } + ) + .def("incoming_index", &InfostateNode::incoming_index) + .def("type", &InfostateNode::type) + .def("depth", &InfostateNode::depth) + .def("is_root_node", &InfostateNode::is_root_node) + .def("is_filler_node", &InfostateNode::is_filler_node) + .def("has_infostate_string", &InfostateNode::has_infostate_string) + .def("infostate_string", &InfostateNode::infostate_string) + .def("num_children", &InfostateNode::num_children) + .def( + "terminal_history", + &InfostateNode::TerminalHistory, + py::return_value_policy::reference_internal + ) + .def("sequence_id", &InfostateNode::sequence_id) + .def("start_sequence_id", &InfostateNode::start_sequence_id) + .def("end_sequence_id", &InfostateNode::end_sequence_id) + .def("all_sequence_ids", &InfostateNode::AllSequenceIds) + .def("decision_id", &InfostateNode::decision_id) + .def( + "legal_actions", &InfostateNode::legal_actions, py::return_value_policy::reference_internal + ) + .def("is_leaf_node", &InfostateNode::is_leaf_node) + .def("terminal_utility", &InfostateNode::terminal_utility) + .def("terminal_chance_reach_prob", &InfostateNode::terminal_chance_reach_prob) + .def("corresponding_states_size", &InfostateNode::corresponding_states_size) + .def( + "corresponding_states", + &InfostateNode::corresponding_states, + py::return_value_policy::reference_internal + ) + .def( + "corresponding_chance_reach_probs", + &InfostateNode::corresponding_chance_reach_probs, + py::return_value_policy::reference_internal + ) + .def( + "child_at", + [](const InfostateNode &node, int index) { + return infostatenode_holder_ptr{node.child_at(index)}; + }, + py::arg("index") + ) + .def("make_certificate", &InfostateNode::MakeCertificate) + .def( + "address_str", + [](const InfostateNode &node) { + std::stringstream ss; + ss << &node; + return ss.str(); + } + ) + .def( + "__iter__", + [](const InfostateNode &node) { + return py::make_iterator( + InfostateNodeChildIterator{node.child_iterator().begin()}, + InfostateNodeChildIterator{node.child_iterator().end()} + ); + } + ) + .def( + "__copy__", + [](const InfostateNode &node) { + throw ForbiddenException( + "InfostateNode cannot be copied, because its " + "lifetime is managed by the owning " + "InfostateTree. Store a variable naming the " + "associated tree to ensure the node's " + "lifetime." + ); + } + ) + .def("__deepcopy__", [](const InfostateNode &node) { + throw ForbiddenException( + "InfostateNode cannot be copied, because its " + "lifetime is managed by the owning " + "InfostateTree. Store a variable naming the " + "associated tree to ensure the node's " + "lifetime." + ); + }); + + py::enum_< InfostateNodeType >(m, "InfostateNodeType") + .value("decision", InfostateNodeType::kDecisionInfostateNode) + .value("observation", InfostateNodeType::kObservationInfostateNode) + .value("terminal", InfostateNodeType::kTerminalInfostateNode) + .export_values(); +} + +struct ToHolderPtrFunctor { + auto operator()(InfostateNode *ptr) const noexcept { return infostatenode_holder_ptr{ptr}; } +}; + +template < + typename ContainerOut, + typename TransformFunctor = ToHolderPtrFunctor, + template < class... > class Container = std::vector, + typename... RestTs > +auto to(const Container< RestTs... > &node_container, TransformFunctor transformer) +{ + ContainerOut internal_vec{}; + internal_vec.reserve(node_container.size()); + std::transform( + node_container.begin(), + node_container.end(), + std::back_insert_iterator(internal_vec), + transformer + ); + return internal_vec; +} + +void init_pyspiel_infostate_tree(::pybind11::module &m) +{ + // Infostate-Tree nodes and NodeType enum + init_pyspiel_infostate_node(m); + // suffix is float despite using double, since python's floating point type + // is double precision. + init_pyspiel_treevector_bundle< double >(m, "Float"); + // a generic tree vector bundle holding any type of python object + init_pyspiel_treevector_bundle< py::object >(m, ""); + // bind a range for every possible id type + init_pyspiel_range< SequenceId >(m, "SequenceIdRange"); + init_pyspiel_range< DecisionId >(m, "DecisionIdRange"); + init_pyspiel_range< LeafId >(m, "LeafIdRange"); + + init_pyspiel_node_id< SequenceId >(m, "SequenceId"); + init_pyspiel_node_id< DecisionId >(m, "DecisionId"); + init_pyspiel_node_id< LeafId >(m, "LeafId"); + + m.attr("UNDEFINED_DECISION_ID") = ::pybind11::cast(kUndefinedDecisionId); + m.attr("UNDEFINED_LEAF_ID") = ::pybind11::cast(kUndefinedLeafId); + m.attr("UNDEFINED_SEQUENCE_ID") = ::pybind11::cast(kUndefinedSequenceId); + m.attr("DUMMY_ROOT_NODE_INFOSTATE") = ::pybind11::cast(algorithms::kDummyRootNodeInfostate); + m.attr("FILLER_INFOSTATE") = ::pybind11::cast(algorithms::kFillerInfostate); + + m.def("is_valid_sf_strategy", &IsValidSfStrategy); + + py::bind_vector< std::vector< std::vector< infostatenode_holder_ptr > > >( + m, "InfostateNodeVector2D" + ); + + py::class_< InfostateTree, std::shared_ptr< InfostateTree > >(m, "InfostateTree", py::is_final()) + .def( + py::init([](const Game &game, Player acting_player, int max_move_limit) { + return MakeInfostateTree(game, acting_player, max_move_limit); + }), + py::arg("game"), + py::arg("acting_player"), + py::arg("max_move_limit") = 1000 + ) + .def( + py::init([](const std::vector< const State * > &start_states, + const std::vector< double > &chance_reach_probs, + std::shared_ptr< Observer > infostate_observer, + Player acting_player, + int max_move_ahead_limit) { + return MakeInfostateTree( + start_states, + chance_reach_probs, + std::move(infostate_observer), + acting_player, + max_move_ahead_limit + ); + }), + py::arg("start_states"), + py::arg("chance_reach_probs"), + py::arg("infostate_observer"), + py::arg("acting_player"), + py::arg("max_move_limit") = 1000 + ) + .def( + py::init([](const std::vector< const InfostateNode * > &start_nodes, + int max_move_ahead_limit) { + return MakeInfostateTree(start_nodes, max_move_ahead_limit); + }), + py::arg("start_nodes"), + py::arg("max_move_limit") = 1000 + ) + .def( + "root", [](InfostateTree &tree) { return infostatenode_holder_ptr{tree.mutable_root()}; } + ) + .def("root_branching_factor", &InfostateTree::root_branching_factor) + .def("acting_player", &InfostateTree::acting_player) + .def("tree_height", &InfostateTree::tree_height) + .def("num_decisions", &InfostateTree::num_decisions) + .def("num_sequences", &InfostateTree::num_sequences) + .def("num_leaves", &InfostateTree::num_leaves) + .def("empty_sequence", &InfostateTree::empty_sequence) + .def( + "observation_infostate", + [](const InfostateTree &tree, const SequenceId &id) { + return const_infostatenode_holder_ptr{tree.observation_infostate(id)}; + }, + py::arg("sequence_id") + ) + .def("all_sequence_ids", &InfostateTree::AllSequenceIds) + .def( + "decision_ids_with_parent_seq", + &InfostateTree::DecisionIdsWithParentSeq, + py::arg("sequence_id") + ) + .def( + "decision_id_for_sequence", &InfostateTree::DecisionIdForSequence, py::arg("sequence_id") + ) + .def( + "decision_for_sequence", + [](InfostateTree &tree, const SequenceId &id) { + auto node_opt = tree.DecisionForSequence(id); + if(not node_opt.has_value()) { + return absl::optional< infostatenode_holder_ptr >{}; + } else { + return absl::optional< infostatenode_holder_ptr >{ + infostatenode_holder_ptr{*node_opt}}; + } + }, + py::arg("sequence_id") + ) + .def("is_leaf_sequence", &InfostateTree::IsLeafSequence) + .def( + "decision_infostate", + [](const InfostateTree &tree, const DecisionId &id) { + return const_infostatenode_holder_ptr{tree.decision_infostate(id)}; + }, + py::arg("decision_id") + ) + .def( + "all_decision_infostates", + [](const InfostateTree &tree) { + return to< std::vector< infostatenode_holder_ptr > >( + tree.AllDecisionInfostates(), ToHolderPtrFunctor{} + ); + } + ) + .def("all_decision_ids", &InfostateTree::AllDecisionIds) + .def( + "decision_id_from_infostate_string", + &InfostateTree::DecisionIdFromInfostateString, + py::arg("infostate_string") + ) + .def( + "leaf_nodes", + [](const InfostateTree &tree) { + return to< std::vector< infostatenode_holder_ptr > >( + tree.leaf_nodes(), ToHolderPtrFunctor{} + ); + } + ) + .def( + "leaf_node", + [](const InfostateTree &tree, const LeafId &id) { + return infostatenode_holder_ptr{tree.leaf_node(id)}; + }, + py::arg("leaf_id") + ) + .def( + "nodes_at_depths", + [](const InfostateTree &tree) { + return to< std::vector< std::vector< infostatenode_holder_ptr > > >( + tree.nodes_at_depths(), + [](const auto &internal_vec) { + return to< std::vector< infostatenode_holder_ptr > >( + internal_vec, ToHolderPtrFunctor{} + ); + } + ); + } + ) + .def( + "nodes_at_depth", + [](const InfostateTree &tree, const py::int_ &depth) { + // we accept a py::int_ here instead of directly asking for a + // size_t, since whatever pybind11 would cast to size_t in order to + // fulifll the type requirement would simply be byte-cast into + // size_t. This would turn negative values into high integers, + // instead of throwing an error. + if(depth < py::int_(0)) { + throw std::invalid_argument("'depth' must be non-negative."); + } + // convert the raw node vector again into a vector of non-deleting + // node unique pointer. + return to< std::vector< infostatenode_holder_ptr > >( + tree.nodes_at_depth(py::cast< size_t >(depth)), ToHolderPtrFunctor{} + ); + }, + py::arg("depth") + ) + .def("best_response", &InfostateTree::BestResponse, py::arg("gradient")) + .def("best_response_value", &InfostateTree::BestResponseValue, py::arg("gradient")) + .def( + "__repr__", + [](const InfostateTree &tree) { + std::ostringstream oss; + oss << tree; + return oss.str(); + } + ) + .def( + "__copy__", + [](const InfostateTree &) { + throw ForbiddenException( + "InfostateTree cannot be copied, because its " + "internal structure is entangled during construction. " + "Create a new tree instead." + ); + } + ) + .def("__deepcopy__", [](const InfostateTree &) { + throw ForbiddenException( + "InfostateTree cannot be copied, because its " + "internal structure is entangled during construction. " + "Create a new tree instead." + ); + }); +} + +} // namespace open_spiel diff --git a/open_spiel/python/pybind11/algorithms_infostate_tree.h b/open_spiel/python/pybind11/algorithms_infostate_tree.h new file mode 100644 index 0000000000..e91413f6e9 --- /dev/null +++ b/open_spiel/python/pybind11/algorithms_infostate_tree.h @@ -0,0 +1,99 @@ +// Copyright 2021 DeepMind Technologies Limited +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OPEN_SPIEL_PYTHON_PYBIND11_INFOSTATE_TREE_H +#define OPEN_SPIEL_PYTHON_PYBIND11_INFOSTATE_TREE_H + +#include "open_spiel/python/pybind11/pybind11.h" +#include "pybind11_abseil/absl_casters.h" + +namespace open_spiel { + +void init_pyspiel_infostate_tree(::pybind11::module &m); + +void init_pyspiel_infostate_node(::pybind11::module &m); + +template < typename T > +void init_pyspiel_treevector_bundle(::pybind11::module &m, std::string &typestr); + +template < typename Self > +void init_pyspiel_node_id(::pybind11::module &m, const std::string &class_name); + +// Bind the Range class +template < class Id > +void init_pyspiel_range(::pybind11::module &m, const std::string &name); + +} // namespace open_spiel + +// include the template definition file +#include "open_spiel/python/pybind11/algorithms_infostate_tree.tcc" + +/// An exception wrapping a forbidden action with given reason. +class ForbiddenException: public std::exception { + public: + explicit ForbiddenException(const char *reason) : m_reason(reason) {} + + [[nodiscard]] const char *what() const noexcept override { return m_reason.c_str(); } + + private: + std::string m_reason; +}; + +/// A smart holder that mimicks the unique pointer api, but doesn't delete the contained object. +/// +/// This class is used for epxosing c++ objects with c++ maintained lifetimes on the python side +/// without running into the risk of double free. +/// While std::unique_ptr< T, py::nodelete> would fulfill the same, such a pointer is not copyable +/// and thus prohibits other bindings, e.g. bind_vector of such a pointer. +/// The MockUniquePtr is copyable, since it doesn't manage any lifetime, and can therefore be used +/// more easily. +template < typename T > +class MockUniquePtr { + public: + MockUniquePtr() noexcept : ptr_(nullptr) {} + explicit MockUniquePtr(T *ptr) noexcept : ptr_(ptr) {} + ~MockUniquePtr() = default; + MockUniquePtr(const MockUniquePtr &other) noexcept : ptr_(other.get()) {} + MockUniquePtr &operator=(const MockUniquePtr &other) noexcept + { + reset(other.get()); + return *this; + } + MockUniquePtr(MockUniquePtr &&other) noexcept : ptr_(other.release()) {} + MockUniquePtr &operator=(MockUniquePtr &&other) noexcept + { + reset(other.release()); + return *this; + } + + [[nodiscard]] T *get() const noexcept { return ptr_; } + T *release() noexcept + { + T *ptr = ptr_; + ptr_ = nullptr; + return ptr; + } + void reset(T *ptr = nullptr) noexcept { ptr_ = ptr; } + + T &operator*() const noexcept { return *ptr_; } + T *operator->() const noexcept { return ptr_; } + explicit operator bool() const noexcept { return ptr_ != nullptr; } + + private: + T *ptr_; +}; + +PYBIND11_DECLARE_HOLDER_TYPE(T, MockUniquePtr< T >, true); + +#endif // OPEN_SPIEL_PYTHON_PYBIND11_INFOSTATE_TREE_H diff --git a/open_spiel/python/pybind11/algorithms_infostate_tree.tcc b/open_spiel/python/pybind11/algorithms_infostate_tree.tcc new file mode 100644 index 0000000000..af1c4779f5 --- /dev/null +++ b/open_spiel/python/pybind11/algorithms_infostate_tree.tcc @@ -0,0 +1,110 @@ +// Copyright 2021 DeepMind Technologies Limited +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OPEN_SPIEL_PYTHON_PYBIND11_INFOSTATE_TREE_TCC +#define OPEN_SPIEL_PYTHON_PYBIND11_INFOSTATE_TREE_TCC + +#include "open_spiel/algorithms/infostate_tree.h" +#include "open_spiel/python/pybind11/algorithms_infostate_tree.h" + +namespace open_spiel { + +using namespace algorithms; + +namespace detail { +template < typename T, template < typename > class TreeVectorDerived, typename IdType > +void _init_pyspiel_treevector_bundle_impl( + ::pybind11::module &m, + const std::string &template_name, + const std::string &type_name +) +{ + ::pybind11::class_< TreeVectorDerived< T >, std::shared_ptr< TreeVectorDerived< T > > >( + m, (template_name + type_name).c_str() + ) + .def(::pybind11::init< const InfostateTree * >(), ::pybind11::arg("tree")) + .def( + ::pybind11::init< const InfostateTree *, std::vector< T > >(), + ::pybind11::arg("tree"), + ::pybind11::arg("vec") + ) + .def( + "__getitem__", + [](const TreeVectorDerived< T > &self, const IdType &id) { return self[id]; }, + ::pybind11::arg("id") + ) + .def("__len__", &TreeVectorDerived< T >::size) + .def("__repr__", &TreeVectorDerived< T >::operator<<) + .def( + "__copy__", + [](const std::shared_ptr< TreeVectorDerived< T > > &self) { + return std::shared_ptr< TreeVectorDerived< T > >(self); + } + ) + .def("__deepcopy__", [](const TreeVectorDerived< T > &self) { + return TreeVectorDerived< T >{self}; + }); +} +} // namespace detail + +template < typename T > +void init_pyspiel_treevector_bundle(::pybind11::module &m, const std::string &typestr) +{ + detail::_init_pyspiel_treevector_bundle_impl< T, TreeplexVector, SequenceId >( + m, "TreeplexVector", typestr + ); + detail::_init_pyspiel_treevector_bundle_impl< T, LeafVector, LeafId >(m, "LeafVector", typestr); + detail::_init_pyspiel_treevector_bundle_impl< T, DecisionVector, DecisionId >( + m, "DecisionVector", typestr + ); +} + +template < typename Self > +void init_pyspiel_node_id(::pybind11::module &m, const std::string &class_name) +{ + ::pybind11::class_< Self >(m, class_name.c_str()) + .def( + ::pybind11::init< size_t, const InfostateTree * >(), + ::pybind11::arg("id_value"), + ::pybind11::arg("tree") + ) + .def("id", &Self::id) + .def("is_undefined", &Self::is_undefined) + .def("next", &Self::next) + .def("__eq__", &Self::operator==) + .def("__ne__", &Self::operator!=) + .def("__copy__", [](const Self &self) { return Self(self); }) + .def("__deepcopy__", [](const Self &self) { return Self(self); }); +} + +template < class Id > +void init_pyspiel_range(::pybind11::module &m, const std::string &name) +{ + ::pybind11::class_< Range< Id > >(m, name.c_str()) + .def( + ::pybind11::init< size_t, size_t, const InfostateTree * >(), + ::pybind11::arg("start"), + ::pybind11::arg("end"), + ::pybind11::arg("tree") + ) + .def( + "__iter__", + [](const Range< Id > &r) { return ::pybind11::make_iterator(r.begin(), r.end()); }, + ::pybind11::keep_alive< 0, 1 >() + ); +} + +} // namespace open_spiel + +#endif // OPEN_SPIEL_PYTHON_PYBIND11_INFOSTATE_TREE_TCC \ No newline at end of file diff --git a/open_spiel/python/pybind11/algorithms_trajectories.cc b/open_spiel/python/pybind11/algorithms_trajectories.cc index 0a31e6d118..b3cda0d99a 100644 --- a/open_spiel/python/pybind11/algorithms_trajectories.cc +++ b/open_spiel/python/pybind11/algorithms_trajectories.cc @@ -54,7 +54,7 @@ void init_pyspiel_algorithms_trajectories(py::module& m) { m.def( "record_batched_trajectories", - [](std::shared_ptr game, + [](const std::shared_ptr& game, const std::vector& policies, const std::unordered_map& state_to_index, int batch_size, bool include_full_observations, int seed, @@ -68,7 +68,7 @@ void init_pyspiel_algorithms_trajectories(py::module& m) { py::class_(m, "TrajectoryRecorder") .def(py::init( - [](std::shared_ptr game, + [](const std::shared_ptr& game, const std::unordered_map& state_to_index, int seed) { return new algorithms::TrajectoryRecorder(*game, state_to_index, diff --git a/open_spiel/python/pybind11/pyspiel.cc b/open_spiel/python/pybind11/pyspiel.cc index a85bea1847..c04fb1d400 100644 --- a/open_spiel/python/pybind11/pyspiel.cc +++ b/open_spiel/python/pybind11/pyspiel.cc @@ -29,6 +29,7 @@ #include "open_spiel/observer.h" #include "open_spiel/python/pybind11/algorithms_corr_dist.h" #include "open_spiel/python/pybind11/algorithms_trajectories.h" +#include "open_spiel/python/pybind11/algorithms_infostate_tree.h" #include "open_spiel/python/pybind11/bots.h" #include "open_spiel/python/pybind11/game_transforms.h" #include "open_spiel/python/pybind11/games_backgammon.h" @@ -624,6 +625,7 @@ PYBIND11_MODULE(pyspiel, m) { throw SpielException(string); }); py::register_exception(m, "SpielError", PyExc_RuntimeError); + py::register_exception(m, "ForbiddenError", PyExc_RuntimeError); // Register other bits of the API. init_pyspiel_bots(m); // Bots and bot-related algorithms. @@ -647,6 +649,7 @@ PYBIND11_MODULE(pyspiel, m) { init_pyspiel_games_trade_comm(m); // Game-specific functions for trade_comm. init_pyspiel_observer(m); // Observers and observations. init_pyspiel_utils(m); // Utilities. + init_pyspiel_infostate_tree(m); // Infostate-Tree and associated classes (Id etc.) // List of optional python submodules. #if OPEN_SPIEL_BUILD_WITH_GAMUT diff --git a/open_spiel/python/tests/infostate_tree_test.py b/open_spiel/python/tests/infostate_tree_test.py new file mode 100644 index 0000000000..8041639ac8 --- /dev/null +++ b/open_spiel/python/tests/infostate_tree_test.py @@ -0,0 +1,342 @@ +# Copyright 2019 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test Python bindings for infostate tree and related classes.""" + +from absl.testing import absltest, parameterized + +import pyspiel +import gc +from copy import copy, deepcopy +import weakref +from typing import Iterator, Iterable + + +class InfostateTreeTest(parameterized.TestCase): + def test_tree_binding(self): + game = pyspiel.load_game("kuhn_poker") + tree = pyspiel.InfostateTree(game, 0) + self.assertEqual(tree.num_sequences(), 13) + + # disallowing copying is enforced + with self.assertRaises(pyspiel.ForbiddenError) as context: + copy(tree) + deepcopy(tree) + + def test_node_tree_lifetime_management(self): + game = pyspiel.load_game("kuhn_poker") + tree = pyspiel.InfostateTree(game, 0) + root = tree.root() + # let's maintain a weak ref to the tree and node to see when the tree and node objects are deallocated + wptr = weakref.ref(tree) + wptr_node = weakref.ref(root) + + # ensure that deleting a node does not delete the underlying object + del root + gc.collect() + # assert the weakref thinks the object is gone + self.assertIsNone(wptr_node()) + # but the tree still holds the actual c++ sided object + root = tree.root() + wptr_node = weakref.ref(root) + self.assertIsNotNone(wptr_node()) + # ensure we can get a shared_ptr from root that keeps tree alive if we lose the 'tree' name + tree_sptr = root.tree() + # grab the tree id + id_tree = id(tree) + # now delete the initial tree ptr + del tree + # ensure that we still hold the object + gc.collect() # force garbage collection + self.assertIsNotNone(wptr()) + self.assertEqual(id(tree_sptr), id_tree) + # now delete the last pointer as well + del tree_sptr + gc.collect() # force garbage collection + self.assertIsNone(wptr()) + + @parameterized.parameters( + [ + # test for matrix mp + dict( + game=pyspiel.load_game("matrix_mp"), + players=[0, 1], + expected_certificate="([" "({}{})" "({}{})" "])", + ), + # test for imperfect info goofspiel + dict( + game=pyspiel.load_game( + "goofspiel", + {"num_cards": 2, "imp_info": True, "points_order": "ascending"}, + ), + players=[0, 1], + expected_certificate="([" "({}{})" "({}{})" "])", + ), + # test for kuhn poker (0 player only) + dict( + game=pyspiel.load_game("kuhn_poker"), + players=[0], + expected_certificate=( + "((" # Root node, 1st is getting a card + "(" # 2nd is getting card + "[" # 1st acts + "((" # 1st bet, and 2nd acts + "(({}))" + "(({}))" + "(({}))" + "(({}))" + "))" + "((" # 1st checks, and 2nd acts + # 2nd checked + "(({}))" + "(({}))" + # 2nd betted + "[({}" + "{})" + "({}" + "{})]" + "))" + "]" + ")" + # Just 2 more copies. + "([(((({}))(({}))(({}))(({}))))(((({}))(({}))[({}{})({}{})]))])" + "([(((({}))(({}))(({}))(({}))))(((({}))(({}))[({}{})({}{})]))])" + "))" + ), + ), + ] + ) + def test_root_certificates(self, game, players, expected_certificate): + for i in players: + tree = pyspiel.InfostateTree(game, i) + self.assertEqual(tree.root().make_certificate(), expected_certificate) + + def check_tree_leaves(self, tree, move_limit): + for leaf_node in tree.leaf_nodes(): + self.assertTrue(leaf_node.is_leaf_node()) + self.assertTrue(leaf_node.has_infostate_string()) + self.assertNotEmpty(leaf_node.corresponding_states()) + + num_states = len(leaf_node.corresponding_states()) + terminal_cnt = 0 + max_move_number = float("-inf") + min_move_number = float("inf") + for state in leaf_node.corresponding_states(): + if state.is_terminal(): + terminal_cnt += 1 + max_move_number = max(max_move_number, state.move_number()) + min_move_number = min(min_move_number, state.move_number()) + self.assertTrue(terminal_cnt == 0 or terminal_cnt == num_states) + self.assertTrue(max_move_number == min_move_number) + if terminal_cnt == 0: + self.assertEqual(max_move_number, move_limit) + else: + self.assertLessEqual(max_move_number, move_limit) + + def check_continuation(self, tree): + leaves = tree.nodes_at_depth(tree.tree_height()) + continuation = pyspiel.InfostateTree(leaves) + self.assertEqual(continuation.root_branching_factor(), len(leaves)) + for i in range(len(leaves)): + leaf_node = leaves[i] + root_node = continuation.root().child_at(i) + self.assertTrue(leaf_node.is_leaf_node()) + if leaf_node.type() != pyspiel.InfostateNodeType.terminal: + self.assertEqual(leaf_node.type(), root_node.type()) + self.assertEqual( + leaf_node.has_infostate_string(), root_node.has_infostate_string() + ) + if leaf_node.has_infostate_string(): + self.assertEqual( + leaf_node.infostate_string(), root_node.infostate_string() + ) + else: + terminal_continuation = continuation.root().child_at(i) + while ( + terminal_continuation.type() + == pyspiel.InfostateNodeType.observation + ): + self.assertFalse(terminal_continuation.is_leaf_node()) + self.assertEqual(terminal_continuation.num_children(), 1) + terminal_continuation = terminal_continuation.child_at(0) + self.assertEqual( + terminal_continuation.type(), pyspiel.InfostateNodeType.terminal + ) + self.assertEqual( + leaf_node.has_infostate_string(), + terminal_continuation.has_infostate_string(), + ) + if leaf_node.has_infostate_string(): + self.assertEqual( + leaf_node.infostate_string(), + terminal_continuation.infostate_string(), + ) + self.assertEqual( + leaf_node.terminal_utility(), + terminal_continuation.terminal_utility(), + ) + self.assertEqual( + leaf_node.terminal_chance_reach_prob(), + terminal_continuation.terminal_chance_reach_prob(), + ) + self.assertEqual( + leaf_node.terminal_history(), + terminal_continuation.terminal_history(), + ) + + def test_depth_limited_tree_kuhn_poker(self): + # Test MakeTree for Kuhn Poker with depth limit 2 + expected_certificate = ( + "(" # + "(" # 1st is getting a card + "(" # 2nd is getting card + "[" # 1st acts - Node J + # Depth cutoff. + "]" + ")" + # Repeat the same for the two other cards. + "([])" # Node Q + "([])" # Node K + ")" + ")" # + ) + tree = pyspiel.InfostateTree(pyspiel.load_game("kuhn_poker"), 0, 2) + self.assertEqual(tree.root().make_certificate(), expected_certificate) + + # Test leaf nodes in Kuhn Poker tree + for acting in tree.leaf_nodes(): + self.assertTrue(acting.is_leaf_node()) + self.assertEqual(acting.type(), pyspiel.InfostateNodeType.decision) + self.assertEqual(len(acting.corresponding_states()), 2) + self.assertTrue(acting.has_infostate_string()) + + @parameterized.parameters( + [ + "kuhn_poker", + "kuhn_poker(players=3)", + "leduc_poker", + "goofspiel(players=2,num_cards=3,imp_info=True)", + "goofspiel(players=3,num_cards=3,imp_info=True)", + ] + ) + def test_depth_limited_trees_all_depths(self, game_name): + game = pyspiel.load_game(game_name) + max_moves = game.max_move_number() + for move_limit in range(max_moves): + for pl in range(game.num_players()): + tree = pyspiel.InfostateTree(game, pl, move_limit) + self.check_tree_leaves(tree, move_limit) + self.check_continuation(tree) + + def test_node_binding(self): + with self.assertRaises(TypeError) as context: + pyspiel.InfostateNode() + self.assertTrue("No constructor defined" in context.exception) + # disallowing copying is enforced + tree = pyspiel.InfostateTree(pyspiel.load_game("kuhn_poker"), 0) + root = tree.root() + with self.assertRaises(pyspiel.ForbiddenError) as context: + copy(root) + deepcopy(root) + + self.assertIsInstance(root, Iterable) + self.assertIsInstance(iter(root), Iterator) + self.assertEqual(root, next(iter(root)).parent()) + + for child in root: + pass + + def test_treevector_binding(self): + game = pyspiel.load_game("kuhn_poker") + tree = pyspiel.InfostateTree(game, 0) + # ensure constructors are bound with the respective args + treeplex_vec = pyspiel.TreeplexVector(tree) + leaf_vec = pyspiel.LeafVector(tree) + decision_vec = pyspiel.DecisionVector(tree) + + self.assertEqual(len(treeplex_vec), 13) + self.assertEqual(len(leaf_vec), 30) + self.assertEqual(len(decision_vec), 6) + + tree.all_decision_ids() + seq_id_range = tree.all_sequence_ids() + n_ids = 0 + for id_ in seq_id_range: + n_ids += 1 + self.assertEqual(n_ids, 13) + seq_id = next(iter(tree.all_sequence_ids())) + seq_id_copy = copy(seq_id) + self.assertEqual(seq_id.id(), 0) + self.assertFalse(seq_id.is_undefined()) + self.assertIsNone(seq_id.next()) + self.assertNotEqual(seq_id, seq_id_copy) + + def test_sequence_id_labeling(self): + for pl in range(2): + tree = pyspiel.InfostateTree(pyspiel.load_game("kuhn_poker"), pl) + + for depth in range(tree.tree_height() + 1): + for node in tree.nodes_at_depth(depth): + self.assertLessEqual( + node.start_sequence_id().id(), node.sequence_id().id() + ) + self.assertLessEqual( + node.end_sequence_id().id(), node.sequence_id().id() + ) + + # Check labeling was done from the deepest nodes. + depth = float("inf") # Some large number. + for id in tree.all_sequence_ids(): + node = tree.observation_infostate(id) + self.assertLessEqual(node.depth(), depth) + depth = node.depth() + # Longer sequences (extensions) must have the corresponding + # infostate nodes placed deeper. + for extension in node.all_sequence_ids(): + child = tree.observation_infostate(extension) + self.assertLess(node.depth(), child.depth()) + + def test_best_response(self): + tree0 = pyspiel.InfostateTree(pyspiel.load_game("matrix_mp"), 0) + tree1 = pyspiel.InfostateTree(pyspiel.load_game("matrix_mp"), 1) + for alpha in range(0, 10): + alpha /= 10.0 + br_value = max(2 * alpha - 1, -2 * alpha + 1) + grad0 = pyspiel.LeafVectorFloat( + tree0, + [1.0 * alpha, -1.0 * (1.0 - alpha), -1.0 * alpha, 1.0 * (1.0 - alpha)], + ) + self.assertAlmostEqual(tree0.best_response_value(grad0), br_value) + + grad1 = pyspiel.LeafVectorFloat( + tree1, + [-1.0 * alpha, 1.0 * (1.0 - alpha), 1.0 * alpha, -1.0 * (1.0 - alpha)], + ) + self.assertAlmostEqual(tree1.best_response_value(grad1), br_value) + + grad0_tp = pyspiel.TreeplexVectorFloat( + tree0, [-1.0 + 2.0 * alpha, 1.0 - 2.0 * alpha, 0.0] + ) + actual_response = tree0.best_response(grad0_tp) + self.assertAlmostEqual(actual_response[0], br_value) + + grad1_tp = pyspiel.TreeplexVectorFloat( + tree1, [1.0 - 2.0 * alpha, -1.0 + 2.0 * alpha, 0.0] + ) + actual_response = tree1.best_response(grad1_tp) + self.assertAlmostEqual(actual_response[0], br_value) + + +if __name__ == "__main__": + absltest.main()