Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

InfostateTree python bind #1054

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
581f58b
initial draft binds for parts of infostate file
maichmueller Apr 21, 2023
b5ea161
addition of remaining binding drafts. still wip
maichmueller Apr 21, 2023
457e8db
undo clang-format changes to existing infostate tree file
maichmueller Apr 21, 2023
c8ffdff
added mock uniq ptr, compiling python module after bind changes
maichmueller Apr 21, 2023
1109d78
further binding corrections, integration into pyspiel lib build
maichmueller Apr 21, 2023
70e8e9b
remove erroneously added clang-format
maichmueller Apr 21, 2023
bf89e4b
removing unnecessary formatting compared to base master
maichmueller Apr 21, 2023
9e9e410
add white space to reduce diff
maichmueller Apr 21, 2023
e2957b9
remaining white spaces added
maichmueller Apr 21, 2023
48c1883
add remaining fields of infostatenode
maichmueller Apr 21, 2023
1ee07e2
removed property for bindings from infostate node
maichmueller Apr 21, 2023
c1ae08b
remove duplicate tree constructor
maichmueller Apr 21, 2023
12b7139
further test additions, added enable_shared_from_this to istatetree
maichmueller Apr 21, 2023
7247f34
update tree test bugs
maichmueller Apr 21, 2023
8a2d15f
add py args to various methods
maichmueller Apr 21, 2023
9bbfcb1
restructure binding code a little, add futher py args
maichmueller Apr 21, 2023
3ae69a8
add copyability to bindings, python tests added
maichmueller Apr 21, 2023
c5b4651
refactor util func, remove residual std::unique_ptr return
maichmueller Apr 27, 2023
deccdb3
fix infostate node child iterator bug
maichmueller Apr 27, 2023
1459280
tiny bug fix in infostate node child iterator adaptor
maichmueller May 12, 2023
5702838
Merge branch 'master' into feat/infostate_binding
maichmueller May 16, 2023
6b8aee0
add filler node check to infostate node
maichmueller May 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ open_spiel/games/universal_poker/double_dummy_solver/
open_spiel/games/hanabi/hanabi-learning-environment/
/open_spiel/pybind11_abseil/
pybind11/
!open_spiel/python/pybind11

# Install artifacts
download_cache/
Expand Down
23 changes: 12 additions & 11 deletions open_spiel/algorithms/infostate_tree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -200,15 +200,15 @@ void InfostateTree::CollectNodesAtDepth(InfostateNode* node, size_t depth) {
CollectNodesAtDepth(child, depth + 1);
}

std::ostream& InfostateTree::operator<<(std::ostream& os) const {
return os << "Infostate tree for player " << acting_player_ << ".\n"
<< "Tree height: " << tree_height_ << '\n'
<< "Root branching: " << root_branching_factor() << '\n'
<< "Number of decision infostate nodes: " << num_decisions() << '\n'
<< "Number of sequences: " << num_sequences() << '\n'
<< "Number of leaves: " << num_leaves() << '\n'
std::ostream& operator<<(std::ostream& os, const InfostateTree& tree) {
return os << "Infostate tree for player " << tree.acting_player_ << ".\n"
<< "Tree height: " << tree.tree_height_ << '\n'
<< "Root branching: " << tree.root_branching_factor() << '\n'
<< "Number of decision infostate nodes: " << tree.num_decisions() << '\n'
<< "Number of sequences: " << tree.num_sequences() << '\n'
<< "Number of leaves: " << tree.num_leaves() << '\n'
<< "Tree certificate: " << '\n'
<< root().MakeCertificate() << '\n';
<< tree.root().MakeCertificate() << '\n';
}

std::unique_ptr<InfostateNode> InfostateTree::MakeNode(
Expand Down Expand Up @@ -485,7 +485,8 @@ absl::optional<DecisionId> InfostateTree::DecisionIdForSequence(
}
}
absl::optional<InfostateNode*> InfostateTree::DecisionForSequence(
const SequenceId& sequence_id) {
const SequenceId& sequence_id) const
{
SPIEL_DCHECK_TRUE(sequence_id.BelongsToTree(this));
InfostateNode* node = sequences_.at(sequence_id.id());
SPIEL_DCHECK_TRUE(node);
Expand Down Expand Up @@ -595,7 +596,7 @@ std::pair<size_t, size_t> InfostateTree::CollectStartEndSequenceIds(
}

std::pair<double, SfStrategy> InfostateTree::BestResponse(
TreeplexVector<double>&& gradient) const {
TreeplexVector<double> gradient) const {
SPIEL_CHECK_EQ(this, gradient.tree());
SPIEL_CHECK_EQ(num_sequences(), gradient.size());
SfStrategy response(this);
Expand Down Expand Up @@ -647,7 +648,7 @@ std::pair<double, SfStrategy> InfostateTree::BestResponse(
return {gradient[empty_sequence()], response};
}

double InfostateTree::BestResponseValue(LeafVector<double>&& gradient) const {
double InfostateTree::BestResponseValue(LeafVector<double> gradient) const {
// Loop over all heights.
for (int d = tree_height_ - 1; d >= 0; d--) {
int left_offset = 0;
Expand Down
18 changes: 13 additions & 5 deletions open_spiel/algorithms/infostate_tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,9 @@ class RangeIterator {
bool operator!=(const RangeIterator& other) const {
return id_ != other.id_ || tree_ != other.tree_;
}
bool operator==(const RangeIterator& other) const {
return !(this->operator!=(other));
}
Id operator*() { return Id(id_, tree_); }
};
template <class Id>
Expand Down Expand Up @@ -285,7 +288,7 @@ std::shared_ptr<InfostateTree> MakeInfostateTree(
const std::vector<InfostateNode*>& start_nodes,
int max_move_ahead_limit = 1000);

class InfostateTree final {
class InfostateTree final : public std::enable_shared_from_this<InfostateTree> {
// Note that only MakeInfostateTree is allowed to call the constructor
// to ensure the trees are always allocated on heap. We do this so that all
// the collected pointers are valid throughout the tree's lifetime even if
Expand All @@ -305,6 +308,10 @@ class InfostateTree final {
const std::vector<const InfostateNode*>&, int);

public:
// -- gain shared ownership of the allocated infostate object
std::shared_ptr< InfostateTree > shared_ptr() { return shared_from_this(); }
std::shared_ptr< const InfostateTree > shared_ptr() const { return shared_from_this(); }

// -- Root accessors ---------------------------------------------------------
const InfostateNode& root() const { return *root_; }
InfostateNode* mutable_root() { return root_.get(); }
Expand Down Expand Up @@ -344,7 +351,7 @@ class InfostateTree final {
// Returns `None` if the sequence is the empty sequence.
absl::optional<DecisionId> DecisionIdForSequence(const SequenceId&) const;
// Returns `None` if the sequence is the empty sequence.
absl::optional<InfostateNode*> DecisionForSequence(const SequenceId&);
absl::optional<InfostateNode*> DecisionForSequence(const SequenceId& sequence_id) const;
// Returns whether the sequence ends with the last action the player can make.
bool IsLeafSequence(const SequenceId&) const;

Expand Down Expand Up @@ -385,13 +392,13 @@ class InfostateTree final {
// Compute best response and value based on gradient from opponents.
// This consumes the gradient vector, as it is used to compute the value.
std::pair<double, SfStrategy> BestResponse(
TreeplexVector<double>&& gradient) const;
TreeplexVector<double> gradient) const;
// Compute best response value based on gradient from opponents over leaves.
// This consumes the gradient vector, as it is used to compute the value.
double BestResponseValue(LeafVector<double>&& gradient) const;
double BestResponseValue(LeafVector<double> gradient) const;

// -- For debugging ----------------------------------------------------------
std::ostream& operator<<(std::ostream& os) const;
friend std::ostream& operator<<(std::ostream& os, const InfostateTree& tree);

private:
const Player acting_player_;
Expand Down Expand Up @@ -550,6 +557,7 @@ class InfostateNode final {
const InfostateNodeType& type() const { return type_; }
size_t depth() const { return depth_; }
bool is_root_node() const { return !parent_; }
bool is_filler_node() const { return infostate_string_ == kFillerInfostate; }
bool has_infostate_string() const {
return infostate_string_ != kFillerInfostate &&
infostate_string_ != kDummyRootNodeInfostate;
Expand Down
3 changes: 3 additions & 0 deletions open_spiel/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ endif()
# List of all Python bindings to add to pyspiel.
include_directories (../pybind11_abseil ../../pybind11/include)
set(PYTHON_BINDINGS ${PYTHON_BINDINGS}
pybind11/algorithms_infostate_tree.cc
pybind11/algorithms_infostate_tree.tcc
pybind11/algorithms_infostate_tree.h
pybind11/algorithms_corr_dist.cc
pybind11/algorithms_corr_dist.h
pybind11/algorithms_trajectories.cc
Expand Down
Loading